Home | History | Annotate | Download | only in models
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for the hybrid tensor forest model."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import random
     21 
     22 # pylint: disable=unused-import
     23 
     24 from tensorflow.contrib.tensor_forest.hybrid.python.models import k_feature_decisions_to_data_then_nn
     25 from tensorflow.contrib.tensor_forest.python import tensor_forest
     26 from tensorflow.python.framework import constant_op
     27 from tensorflow.python.framework import test_util
     28 from tensorflow.python.framework.ops import Operation
     29 from tensorflow.python.framework.ops import Tensor
     30 from tensorflow.python.ops import variable_scope
     31 from tensorflow.python.platform import googletest
     32 
     33 
     34 class KFeatureDecisionsToDataThenNNTest(test_util.TensorFlowTestCase):
     35 
     36   def setUp(self):
     37     self.params = tensor_forest.ForestHParams(
     38         num_classes=2,
     39         num_features=31,
     40         layer_size=11,
     41         num_layers=13,
     42         num_trees=17,
     43         connection_probability=0.1,
     44         hybrid_tree_depth=4,
     45         regularization_strength=0.01,
     46         regularization="",
     47         base_random_seed=10,
     48         hybrid_feature_bagging_fraction=1.0,
     49         learning_rate=0.01,
     50         weight_init_mean=0.0,
     51         weight_init_std=0.1)
     52     self.params.regression = False
     53     self.params.num_nodes = 2**self.params.hybrid_tree_depth - 1
     54     self.params.num_leaves = 2**(self.params.hybrid_tree_depth - 1)
     55     self.params.num_features_per_node = (self.params.feature_bagging_fraction *
     56                                          self.params.num_features)
     57 
     58   def testKFeatureInferenceConstruction(self):
     59     # pylint: disable=W0612
     60     data = constant_op.constant(
     61         [[random.uniform(-1, 1) for i in range(self.params.num_features)]
     62          for _ in range(100)])
     63 
     64     with variable_scope.variable_scope(
     65         "KFeatureDecisionsToDataThenNNTest.testKFeatureInferenceContruction"):
     66       graph_builder = (
     67           k_feature_decisions_to_data_then_nn.KFeatureDecisionsToDataThenNN(
     68               self.params))
     69       graph = graph_builder.inference_graph(data, None)
     70 
     71       self.assertTrue(isinstance(graph, Tensor))
     72 
     73   def testKFeatureTrainingConstruction(self):
     74     # pylint: disable=W0612
     75     data = constant_op.constant(
     76         [[random.uniform(-1, 1) for i in range(self.params.num_features)]
     77          for _ in range(100)])
     78 
     79     labels = [1 for _ in range(100)]
     80 
     81     with variable_scope.variable_scope(
     82         "KFeatureDecisionsToDataThenNNTest.testKFeatureTrainingContruction"):
     83       graph_builder = (
     84           k_feature_decisions_to_data_then_nn.KFeatureDecisionsToDataThenNN(
     85               self.params))
     86       graph = graph_builder.training_graph(data, labels, None)
     87 
     88       self.assertTrue(isinstance(graph, Operation))
     89 
     90 
     91 if __name__ == "__main__":
     92   googletest.main()
     93