1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for the hybrid tensor forest model.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import random 21 22 # pylint: disable=unused-import 23 24 from tensorflow.contrib.tensor_forest.hybrid.python.models import decisions_to_data_then_nn 25 from tensorflow.contrib.tensor_forest.python import tensor_forest 26 from tensorflow.python.framework import constant_op 27 from tensorflow.python.framework import test_util 28 from tensorflow.python.framework.ops import Operation 29 from tensorflow.python.framework.ops import Tensor 30 from tensorflow.python.ops import variable_scope 31 from tensorflow.python.platform import googletest 32 33 34 class DecisionsToDataThenNNTest(test_util.TensorFlowTestCase): 35 36 def setUp(self): 37 self.params = tensor_forest.ForestHParams( 38 num_classes=2, 39 num_features=31, 40 layer_size=11, 41 num_layers=13, 42 num_trees=17, 43 connection_probability=0.1, 44 hybrid_tree_depth=4, 45 regularization_strength=0.01, 46 learning_rate=0.01, 47 regularization="", 48 weight_init_mean=0.0, 49 weight_init_std=0.1) 50 self.params.regression = False 51 self.params.num_nodes = 2**self.params.hybrid_tree_depth - 1 52 self.params.num_leaves = 2**(self.params.hybrid_tree_depth - 1) 53 54 def testHParams(self): 55 self.assertEquals(self.params.num_classes, 2) 56 self.assertEquals(self.params.num_features, 31) 57 self.assertEquals(self.params.layer_size, 11) 58 self.assertEquals(self.params.num_layers, 13) 59 self.assertEquals(self.params.num_trees, 17) 60 self.assertEquals(self.params.hybrid_tree_depth, 4) 61 self.assertEquals(self.params.connection_probability, 0.1) 62 63 # Building the graphs modifies the params. 64 with variable_scope.variable_scope("DecisionsToDataThenNNTest_testHParams"): 65 # pylint: disable=W0612 66 graph_builder = decisions_to_data_then_nn.DecisionsToDataThenNN( 67 self.params) 68 69 # Tree with depth 4 should have 2**0 + 2**1 + 2**2 + 2**3 = 15 nodes. 70 self.assertEquals(self.params.num_nodes, 15) 71 72 def testConstructionPollution(self): 73 """Ensure that graph building doesn't modify the params in a bad way.""" 74 # pylint: disable=W0612 75 data = [[random.uniform(-1, 1) for i in range(self.params.num_features)] 76 for _ in range(100)] 77 78 self.assertTrue(isinstance(self.params, tensor_forest.ForestHParams)) 79 self.assertFalse( 80 isinstance(self.params.num_trees, tensor_forest.ForestHParams)) 81 82 with variable_scope.variable_scope( 83 "DecisionsToDataThenNNTest_testConstructionPollution"): 84 graph_builder = decisions_to_data_then_nn.DecisionsToDataThenNN( 85 self.params) 86 87 self.assertTrue(isinstance(self.params, tensor_forest.ForestHParams)) 88 self.assertFalse( 89 isinstance(self.params.num_trees, tensor_forest.ForestHParams)) 90 91 def testInferenceConstruction(self): 92 # pylint: disable=W0612 93 data = constant_op.constant( 94 [[random.uniform(-1, 1) for i in range(self.params.num_features)] 95 for _ in range(100)]) 96 97 with variable_scope.variable_scope( 98 "DecisionsToDataThenNNTest_testInferenceConstruction"): 99 graph_builder = decisions_to_data_then_nn.DecisionsToDataThenNN( 100 self.params) 101 graph = graph_builder.inference_graph(data, None) 102 103 self.assertTrue(isinstance(graph, Tensor)) 104 105 def testTrainingConstruction(self): 106 # pylint: disable=W0612 107 data = constant_op.constant( 108 [[random.uniform(-1, 1) for i in range(self.params.num_features)] 109 for _ in range(100)]) 110 111 labels = [1 for _ in range(100)] 112 113 with variable_scope.variable_scope( 114 "DecisionsToDataThenNNTest_testTrainingConstruction"): 115 graph_builder = decisions_to_data_then_nn.DecisionsToDataThenNN( 116 self.params) 117 graph = graph_builder.training_graph(data, labels, None) 118 119 self.assertTrue(isinstance(graph, Operation)) 120 121 122 if __name__ == "__main__": 123 googletest.main() 124