Home | History | Annotate | Download | only in v4
      1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 // =============================================================================
     15 #include "tensorflow/contrib/tensor_forest/kernels/v4/params.h"
     16 #include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h"
     17 #include "tensorflow/core/platform/test.h"
     18 
     19 namespace {
     20 
     21 using tensorflow::tensorforest::DepthDependentParam;
     22 using tensorflow::tensorforest::ResolveParam;
     23 
     24 TEST(ParamsTest, TestConstant) {
     25   DepthDependentParam param;
     26   param.set_constant_value(10.0);
     27 
     28   ASSERT_EQ(ResolveParam(param, 0), 10.0);
     29   ASSERT_EQ(ResolveParam(param, 100), 10.0);
     30 }
     31 
     32 TEST(ParamsTest, TestLinear) {
     33   DepthDependentParam param;
     34   auto* linear = param.mutable_linear();
     35   linear->set_y_intercept(100.0);
     36   linear->set_slope(-10.0);
     37   linear->set_min_val(23.0);
     38   linear->set_max_val(90.0);
     39 
     40   ASSERT_EQ(ResolveParam(param, 0), 90);
     41   ASSERT_EQ(ResolveParam(param, 1), 90);
     42   ASSERT_EQ(ResolveParam(param, 2), 80);
     43 
     44   ASSERT_EQ(ResolveParam(param, 30), 23);
     45 }
     46 
     47 TEST(ParamsTest, TestExponential) {
     48   DepthDependentParam param;
     49   auto* expo = param.mutable_exponential();
     50   expo->set_bias(100.0);
     51   expo->set_base(10.0);
     52   expo->set_multiplier(-1.0);
     53   expo->set_depth_multiplier(1.0);
     54 
     55   ASSERT_EQ(ResolveParam(param, 0), 99);
     56   ASSERT_EQ(ResolveParam(param, 1), 90);
     57   ASSERT_EQ(ResolveParam(param, 2), 0);
     58 }
     59 
     60 TEST(ParamsTest, TestThreshold) {
     61   DepthDependentParam param;
     62   auto* threshold = param.mutable_threshold();
     63   threshold->set_on_value(100.0);
     64   threshold->set_off_value(10.0);
     65   threshold->set_threshold(5.0);
     66 
     67   ASSERT_EQ(ResolveParam(param, 0), 10);
     68   ASSERT_EQ(ResolveParam(param, 4), 10);
     69   ASSERT_EQ(ResolveParam(param, 5), 100);
     70   ASSERT_EQ(ResolveParam(param, 6), 100);
     71 }
     72 
     73 }  // namespace
     74