Home | History | Annotate | Download | only in python
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tf.contrib.tensor_forest.python.tensor_forest."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.contrib.tensor_forest.python import tensor_forest
     22 from tensorflow.python.framework import ops
     23 from tensorflow.python.framework import sparse_tensor
     24 from tensorflow.python.framework import test_util
     25 from tensorflow.python.platform import googletest
     26 
     27 
     28 class TensorForestTest(test_util.TensorFlowTestCase):
     29 
     30   def testForestHParams(self):
     31     hparams = tensor_forest.ForestHParams(
     32         num_classes=2,
     33         num_trees=100,
     34         max_nodes=1000,
     35         split_after_samples=25,
     36         num_features=60).fill()
     37     self.assertEquals(2, hparams.num_classes)
     38     self.assertEquals(3, hparams.num_output_columns)
     39     self.assertEquals(10, hparams.num_splits_to_consider)
     40     # Default value of valid_leaf_threshold
     41     self.assertEquals(1, hparams.valid_leaf_threshold)
     42     self.assertEquals(0, hparams.base_random_seed)
     43 
     44   def testForestHParamsBigTree(self):
     45     hparams = tensor_forest.ForestHParams(
     46         num_classes=2,
     47         num_trees=100,
     48         max_nodes=1000000,
     49         split_after_samples=25,
     50         num_features=1000).fill()
     51     self.assertEquals(31, hparams.num_splits_to_consider)
     52 
     53   def testForestHParamsStringParams(self):
     54     hparams = tensor_forest.ForestHParams(
     55         num_classes=2,
     56         num_trees=100,
     57         max_nodes=1000000,
     58         split_after_samples="25",
     59         num_splits_to_consider="1000000",
     60         num_features=1000).fill()
     61     self.assertEquals("1000000", hparams.num_splits_to_consider)
     62 
     63   def testTrainingConstructionClassification(self):
     64     input_data = [[-1., 0.], [-1., 2.],  # node 1
     65                   [1., 0.], [1., -2.]]  # node 2
     66     input_labels = [0, 1, 2, 3]
     67 
     68     params = tensor_forest.ForestHParams(
     69         num_classes=4,
     70         num_features=2,
     71         num_trees=10,
     72         max_nodes=1000,
     73         split_after_samples=25).fill()
     74 
     75     graph_builder = tensor_forest.RandomForestGraphs(params)
     76     graph = graph_builder.training_graph(input_data, input_labels)
     77     self.assertTrue(isinstance(graph, ops.Operation))
     78 
     79   def testTrainingConstructionRegression(self):
     80     input_data = [[-1., 0.], [-1., 2.],  # node 1
     81                   [1., 0.], [1., -2.]]  # node 2
     82     input_labels = [0, 1, 2, 3]
     83 
     84     params = tensor_forest.ForestHParams(
     85         num_classes=4,
     86         num_features=2,
     87         num_trees=10,
     88         max_nodes=1000,
     89         split_after_samples=25,
     90         regression=True).fill()
     91 
     92     graph_builder = tensor_forest.RandomForestGraphs(params)
     93     graph = graph_builder.training_graph(input_data, input_labels)
     94     self.assertTrue(isinstance(graph, ops.Operation))
     95 
     96   def testInferenceConstruction(self):
     97     input_data = [[-1., 0.], [-1., 2.],  # node 1
     98                   [1., 0.], [1., -2.]]  # node 2
     99 
    100     params = tensor_forest.ForestHParams(
    101         num_classes=4,
    102         num_features=2,
    103         num_trees=10,
    104         max_nodes=1000,
    105         split_after_samples=25).fill()
    106 
    107     graph_builder = tensor_forest.RandomForestGraphs(params)
    108     probs, paths, var = graph_builder.inference_graph(input_data)
    109     self.assertTrue(isinstance(probs, ops.Tensor))
    110     self.assertTrue(isinstance(paths, ops.Tensor))
    111     self.assertTrue(isinstance(var, ops.Tensor))
    112 
    113   def testTrainingConstructionClassificationSparse(self):
    114     input_data = sparse_tensor.SparseTensor(
    115         indices=[[0, 0], [0, 3], [1, 0], [1, 7], [2, 1], [3, 9]],
    116         values=[-1.0, 0.0, -1., 2., 1., -2.0],
    117         dense_shape=[4, 10])
    118     input_labels = [0, 1, 2, 3]
    119 
    120     params = tensor_forest.ForestHParams(
    121         num_classes=4,
    122         num_features=10,
    123         num_trees=10,
    124         max_nodes=1000,
    125         split_after_samples=25).fill()
    126 
    127     graph_builder = tensor_forest.RandomForestGraphs(params)
    128     graph = graph_builder.training_graph(input_data, input_labels)
    129     self.assertTrue(isinstance(graph, ops.Operation))
    130 
    131   def testInferenceConstructionSparse(self):
    132     input_data = sparse_tensor.SparseTensor(
    133         indices=[[0, 0], [0, 3],
    134                  [1, 0], [1, 7],
    135                  [2, 1],
    136                  [3, 9]],
    137         values=[-1.0, 0.0,
    138                 -1., 2.,
    139                 1.,
    140                 -2.0],
    141         dense_shape=[4, 10])
    142 
    143     params = tensor_forest.ForestHParams(
    144         num_classes=4,
    145         num_features=10,
    146         num_trees=10,
    147         max_nodes=1000,
    148         regression=True,
    149         split_after_samples=25).fill()
    150 
    151     graph_builder = tensor_forest.RandomForestGraphs(params)
    152     probs, paths, var = graph_builder.inference_graph(input_data)
    153     self.assertTrue(isinstance(probs, ops.Tensor))
    154     self.assertTrue(isinstance(paths, ops.Tensor))
    155     self.assertTrue(isinstance(var, ops.Tensor))
    156 
    157 
    158 if __name__ == "__main__":
    159   googletest.main()
    160