1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for the hybrid tensor forest model.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import random 21 22 # pylint: disable=unused-import 23 24 from tensorflow.contrib.tensor_forest.hybrid.python.models import k_feature_decisions_to_data_then_nn 25 from tensorflow.contrib.tensor_forest.python import tensor_forest 26 from tensorflow.python.framework import constant_op 27 from tensorflow.python.framework import test_util 28 from tensorflow.python.framework.ops import Operation 29 from tensorflow.python.framework.ops import Tensor 30 from tensorflow.python.ops import variable_scope 31 from tensorflow.python.platform import googletest 32 33 34 class KFeatureDecisionsToDataThenNNTest(test_util.TensorFlowTestCase): 35 36 def setUp(self): 37 self.params = tensor_forest.ForestHParams( 38 num_classes=2, 39 num_features=31, 40 layer_size=11, 41 num_layers=13, 42 num_trees=17, 43 connection_probability=0.1, 44 hybrid_tree_depth=4, 45 regularization_strength=0.01, 46 regularization="", 47 base_random_seed=10, 48 hybrid_feature_bagging_fraction=1.0, 49 learning_rate=0.01, 50 weight_init_mean=0.0, 51 weight_init_std=0.1) 52 self.params.regression = False 53 self.params.num_nodes = 2**self.params.hybrid_tree_depth - 1 54 self.params.num_leaves = 2**(self.params.hybrid_tree_depth - 1) 55 self.params.num_features_per_node = (self.params.feature_bagging_fraction * 56 self.params.num_features) 57 58 def testKFeatureInferenceConstruction(self): 59 # pylint: disable=W0612 60 data = constant_op.constant( 61 [[random.uniform(-1, 1) for i in range(self.params.num_features)] 62 for _ in range(100)]) 63 64 with variable_scope.variable_scope( 65 "KFeatureDecisionsToDataThenNNTest.testKFeatureInferenceContruction"): 66 graph_builder = ( 67 k_feature_decisions_to_data_then_nn.KFeatureDecisionsToDataThenNN( 68 self.params)) 69 graph = graph_builder.inference_graph(data, None) 70 71 self.assertTrue(isinstance(graph, Tensor)) 72 73 def testKFeatureTrainingConstruction(self): 74 # pylint: disable=W0612 75 data = constant_op.constant( 76 [[random.uniform(-1, 1) for i in range(self.params.num_features)] 77 for _ in range(100)]) 78 79 labels = [1 for _ in range(100)] 80 81 with variable_scope.variable_scope( 82 "KFeatureDecisionsToDataThenNNTest.testKFeatureTrainingContruction"): 83 graph_builder = ( 84 k_feature_decisions_to_data_then_nn.KFeatureDecisionsToDataThenNN( 85 self.params)) 86 graph = graph_builder.training_graph(data, labels, None) 87 88 self.assertTrue(isinstance(graph, Operation)) 89 90 91 if __name__ == "__main__": 92 googletest.main() 93