1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #include "tensorflow/contrib/tensor_forest/kernels/v4/params.h" 16 #include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h" 17 #include "tensorflow/core/platform/test.h" 18 19 namespace { 20 21 using tensorflow::tensorforest::DepthDependentParam; 22 using tensorflow::tensorforest::ResolveParam; 23 24 TEST(ParamsTest, TestConstant) { 25 DepthDependentParam param; 26 param.set_constant_value(10.0); 27 28 ASSERT_EQ(ResolveParam(param, 0), 10.0); 29 ASSERT_EQ(ResolveParam(param, 100), 10.0); 30 } 31 32 TEST(ParamsTest, TestLinear) { 33 DepthDependentParam param; 34 auto* linear = param.mutable_linear(); 35 linear->set_y_intercept(100.0); 36 linear->set_slope(-10.0); 37 linear->set_min_val(23.0); 38 linear->set_max_val(90.0); 39 40 ASSERT_EQ(ResolveParam(param, 0), 90); 41 ASSERT_EQ(ResolveParam(param, 1), 90); 42 ASSERT_EQ(ResolveParam(param, 2), 80); 43 44 ASSERT_EQ(ResolveParam(param, 30), 23); 45 } 46 47 TEST(ParamsTest, TestExponential) { 48 DepthDependentParam param; 49 auto* expo = param.mutable_exponential(); 50 expo->set_bias(100.0); 51 expo->set_base(10.0); 52 expo->set_multiplier(-1.0); 53 expo->set_depth_multiplier(1.0); 54 55 ASSERT_EQ(ResolveParam(param, 0), 99); 56 ASSERT_EQ(ResolveParam(param, 1), 90); 57 ASSERT_EQ(ResolveParam(param, 2), 0); 58 } 59 60 TEST(ParamsTest, TestThreshold) { 61 DepthDependentParam param; 62 auto* threshold = param.mutable_threshold(); 63 threshold->set_on_value(100.0); 64 threshold->set_off_value(10.0); 65 threshold->set_threshold(5.0); 66 67 ASSERT_EQ(ResolveParam(param, 0), 10); 68 ASSERT_EQ(ResolveParam(param, 4), 10); 69 ASSERT_EQ(ResolveParam(param, 5), 100); 70 ASSERT_EQ(ResolveParam(param, 6), 100); 71 } 72 73 } // namespace 74