# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.tensor_forest.ops.tensor_forest."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest


class TensorForestTest(test_util.TensorFlowTestCase):
  """Tests ForestHParams.fill() and RandomForestGraphs graph construction."""

  def testForestHParams(self):
    """fill() populates derived fields and defaults for a small forest."""
    hparams = tensor_forest.ForestHParams(
        num_classes=2,
        num_trees=100,
        max_nodes=1000,
        split_after_samples=25,
        num_features=60).fill()
    self.assertEqual(2, hparams.num_classes)
    self.assertEqual(3, hparams.num_output_columns)
    self.assertEqual(10, hparams.num_splits_to_consider)
    # Default value of valid_leaf_threshold
    self.assertEqual(1, hparams.valid_leaf_threshold)
    self.assertEqual(0, hparams.base_random_seed)

  def testForestHParamsBigTree(self):
    """With max_nodes=1000000 and num_features=1000, fill() picks 31 splits."""
    hparams = tensor_forest.ForestHParams(
        num_classes=2,
        num_trees=100,
        max_nodes=1000000,
        split_after_samples=25,
        num_features=1000).fill()
    self.assertEqual(31, hparams.num_splits_to_consider)

  def testForestHParamsStringParams(self):
    """An explicitly passed string num_splits_to_consider survives fill()."""
    hparams = tensor_forest.ForestHParams(
        num_classes=2,
        num_trees=100,
        max_nodes=1000000,
        split_after_samples="25",
        num_splits_to_consider="1000000",
        num_features=1000).fill()
    self.assertEqual("1000000", hparams.num_splits_to_consider)

  def testTrainingConstructionClassification(self):
    """training_graph() on dense classification input returns an Operation."""
    input_data = [[-1., 0.], [-1., 2.],  # node 1
                  [1., 0.], [1., -2.]]  # node 2
    input_labels = [0, 1, 2, 3]

    params = tensor_forest.ForestHParams(
        num_classes=4,
        num_features=2,
        num_trees=10,
        max_nodes=1000,
        split_after_samples=25).fill()

    graph_builder = tensor_forest.RandomForestGraphs(params)
    graph = graph_builder.training_graph(input_data, input_labels)
    self.assertIsInstance(graph, ops.Operation)

  def testTrainingConstructionRegression(self):
    """training_graph() on dense regression input returns an Operation."""
    input_data = [[-1., 0.], [-1., 2.],  # node 1
                  [1., 0.], [1., -2.]]  # node 2
    input_labels = [0, 1, 2, 3]

    params = tensor_forest.ForestHParams(
        num_classes=4,
        num_features=2,
        num_trees=10,
        max_nodes=1000,
        split_after_samples=25,
        regression=True).fill()

    graph_builder = tensor_forest.RandomForestGraphs(params)
    graph = graph_builder.training_graph(input_data, input_labels)
    self.assertIsInstance(graph, ops.Operation)

  def testInferenceConstruction(self):
    """inference_graph() on dense input returns three Tensors."""
    input_data = [[-1., 0.], [-1., 2.],  # node 1
                  [1., 0.], [1., -2.]]  # node 2

    params = tensor_forest.ForestHParams(
        num_classes=4,
        num_features=2,
        num_trees=10,
        max_nodes=1000,
        split_after_samples=25).fill()

    graph_builder = tensor_forest.RandomForestGraphs(params)
    probs, paths, var = graph_builder.inference_graph(input_data)
    self.assertIsInstance(probs, ops.Tensor)
    self.assertIsInstance(paths, ops.Tensor)
    self.assertIsInstance(var, ops.Tensor)

  def testTrainingConstructionClassificationSparse(self):
    """training_graph() accepts a SparseTensor and returns an Operation."""
    input_data = sparse_tensor.SparseTensor(
        indices=[[0, 0], [0, 3], [1, 0], [1, 7], [2, 1], [3, 9]],
        values=[-1.0, 0.0, -1., 2., 1., -2.0],
        dense_shape=[4, 10])
    input_labels = [0, 1, 2, 3]

    params = tensor_forest.ForestHParams(
        num_classes=4,
        num_features=10,
        num_trees=10,
        max_nodes=1000,
        split_after_samples=25).fill()

    graph_builder = tensor_forest.RandomForestGraphs(params)
    graph = graph_builder.training_graph(input_data, input_labels)
    self.assertIsInstance(graph, ops.Operation)

  def testInferenceConstructionSparse(self):
    """inference_graph() accepts a SparseTensor and returns three Tensors."""
    input_data = sparse_tensor.SparseTensor(
        indices=[[0, 0], [0, 3],
                 [1, 0], [1, 7],
                 [2, 1],
                 [3, 9]],
        values=[-1.0, 0.0,
                -1., 2.,
                1.,
                -2.0],
        dense_shape=[4, 10])

    # Unlike the dense inference test, this one builds a regression forest.
    params = tensor_forest.ForestHParams(
        num_classes=4,
        num_features=10,
        num_trees=10,
        max_nodes=1000,
        regression=True,
        split_after_samples=25).fill()

    graph_builder = tensor_forest.RandomForestGraphs(params)
    probs, paths, var = graph_builder.inference_graph(input_data)
    self.assertIsInstance(probs, ops.Tensor)
    self.assertIsInstance(paths, ops.Tensor)
    self.assertIsInstance(var, ops.Tensor)


if __name__ == "__main__":
  googletest.main()