Home | History | Annotate | Download | only in models
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for the hybrid tensor forest model."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import random
     21 
     22 # pylint: disable=unused-import
     23 
     24 from tensorflow.contrib.tensor_forest.hybrid.python.models import decisions_to_data_then_nn
     25 from tensorflow.contrib.tensor_forest.python import tensor_forest
     26 from tensorflow.python.framework import constant_op
     27 from tensorflow.python.framework import test_util
     28 from tensorflow.python.framework.ops import Operation
     29 from tensorflow.python.framework.ops import Tensor
     30 from tensorflow.python.ops import variable_scope
     31 from tensorflow.python.platform import googletest
     32 
     33 
     34 class DecisionsToDataThenNNTest(test_util.TensorFlowTestCase):
     35 
     36   def setUp(self):
     37     self.params = tensor_forest.ForestHParams(
     38         num_classes=2,
     39         num_features=31,
     40         layer_size=11,
     41         num_layers=13,
     42         num_trees=17,
     43         connection_probability=0.1,
     44         hybrid_tree_depth=4,
     45         regularization_strength=0.01,
     46         learning_rate=0.01,
     47         regularization="",
     48         weight_init_mean=0.0,
     49         weight_init_std=0.1)
     50     self.params.regression = False
     51     self.params.num_nodes = 2**self.params.hybrid_tree_depth - 1
     52     self.params.num_leaves = 2**(self.params.hybrid_tree_depth - 1)
     53 
     54   def testHParams(self):
     55     self.assertEquals(self.params.num_classes, 2)
     56     self.assertEquals(self.params.num_features, 31)
     57     self.assertEquals(self.params.layer_size, 11)
     58     self.assertEquals(self.params.num_layers, 13)
     59     self.assertEquals(self.params.num_trees, 17)
     60     self.assertEquals(self.params.hybrid_tree_depth, 4)
     61     self.assertEquals(self.params.connection_probability, 0.1)
     62 
     63     # Building the graphs modifies the params.
     64     with variable_scope.variable_scope("DecisionsToDataThenNNTest_testHParams"):
     65       # pylint: disable=W0612
     66       graph_builder = decisions_to_data_then_nn.DecisionsToDataThenNN(
     67           self.params)
     68 
     69       # Tree with depth 4 should have 2**0 + 2**1 + 2**2 + 2**3 = 15 nodes.
     70       self.assertEquals(self.params.num_nodes, 15)
     71 
     72   def testConstructionPollution(self):
     73     """Ensure that graph building doesn't modify the params in a bad way."""
     74     # pylint: disable=W0612
     75     data = [[random.uniform(-1, 1) for i in range(self.params.num_features)]
     76             for _ in range(100)]
     77 
     78     self.assertTrue(isinstance(self.params, tensor_forest.ForestHParams))
     79     self.assertFalse(
     80         isinstance(self.params.num_trees, tensor_forest.ForestHParams))
     81 
     82     with variable_scope.variable_scope(
     83         "DecisionsToDataThenNNTest_testConstructionPollution"):
     84       graph_builder = decisions_to_data_then_nn.DecisionsToDataThenNN(
     85           self.params)
     86 
     87       self.assertTrue(isinstance(self.params, tensor_forest.ForestHParams))
     88       self.assertFalse(
     89           isinstance(self.params.num_trees, tensor_forest.ForestHParams))
     90 
     91   def testInferenceConstruction(self):
     92     # pylint: disable=W0612
     93     data = constant_op.constant(
     94         [[random.uniform(-1, 1) for i in range(self.params.num_features)]
     95          for _ in range(100)])
     96 
     97     with variable_scope.variable_scope(
     98         "DecisionsToDataThenNNTest_testInferenceConstruction"):
     99       graph_builder = decisions_to_data_then_nn.DecisionsToDataThenNN(
    100           self.params)
    101       graph = graph_builder.inference_graph(data, None)
    102 
    103       self.assertTrue(isinstance(graph, Tensor))
    104 
    105   def testTrainingConstruction(self):
    106     # pylint: disable=W0612
    107     data = constant_op.constant(
    108         [[random.uniform(-1, 1) for i in range(self.params.num_features)]
    109          for _ in range(100)])
    110 
    111     labels = [1 for _ in range(100)]
    112 
    113     with variable_scope.variable_scope(
    114         "DecisionsToDataThenNNTest_testTrainingConstruction"):
    115       graph_builder = decisions_to_data_then_nn.DecisionsToDataThenNN(
    116           self.params)
    117       graph = graph_builder.training_graph(data, labels, None)
    118 
    119       self.assertTrue(isinstance(graph, Operation))
    120 
    121 
# Run the tests through the googletest runner when executed as a script.
if __name__ == "__main__":
  googletest.main()
    124