# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A model that places a hard decision tree embedding before a neural net."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib import layers
from tensorflow.contrib.tensor_forest.hybrid.python import hybrid_model
from tensorflow.contrib.tensor_forest.hybrid.python.layers import decisions_to_data
from tensorflow.contrib.tensor_forest.hybrid.python.layers import fully_connected
from tensorflow.python.ops import nn_ops
from tensorflow.python.training import adagrad


class HardDecisionsToDataThenNN(hybrid_model.HybridModel):
  """A model that treats tree inference as hard at test.

  The model is a two-element layer stack: a `HardDecisionsToDataLayer`
  followed by a `FullyConnectedLayer`. The stack's result is fed through a
  final `layers.fully_connected` output layer.

  NOTE(review): `inference_graph` requests the `soft=True` path of the first
  layer while `training_inference_graph` requests `soft=False`; confirm that
  this matches the "hard at test" wording above.
  """

  def __init__(self,
               params,
               device_assigner=None,
               optimizer_class=adagrad.AdagradOptimizer,
               **kwargs):
    """Builds the layer stack on top of the base `HybridModel`.

    Args:
      params: Model parameters; forwarded to the base class and to each layer.
      device_assigner: Optional device assigner; forwarded to the base class
        and to each layer.
      optimizer_class: Optimizer class forwarded to the base class. Defaults
        to `adagrad.AdagradOptimizer`.
      **kwargs: Extra keyword arguments forwarded unchanged to the base class.
    """
    super(HardDecisionsToDataThenNN, self).__init__(
        params,
        device_assigner=device_assigner,
        optimizer_class=optimizer_class,
        **kwargs)

    # The second positional argument (0, 1) is each layer's position in the
    # stack.
    self.layers = [
        decisions_to_data.HardDecisionsToDataLayer(params, 0, device_assigner),
        fully_connected.FullyConnectedLayer(
            params, 1, device_assigner=device_assigner),
    ]

  def _base_inference(self, data, data_spec=None, soft=False):
    """Runs `data` through the layer stack and the final output layer.

    Args:
      data: Input handed to the first layer.
      data_spec: Not used by this method; accepted so the signature lines up
        with the public inference methods.
      soft: If True, the first layer is evaluated with its
        `soft_inference_graph` method; otherwise it is evaluated through
        `self._do_layer_inference` like the remaining layers.

    Returns:
      The output of a `layers.fully_connected` layer with softmax activation.
      Its width is 1 when `self.is_regression` is true, otherwise
      `self.params.num_classes`.
    """
    if soft:
      inference_result = self.layers[0].soft_inference_graph(data)
    else:
      inference_result = self._do_layer_inference(self.layers[0], data)

    for layer in self.layers[1:]:
      inference_result = self._do_layer_inference(layer, inference_result)

    output_size = 1 if self.is_regression else self.params.num_classes
    # NOTE(review): softmax over a width-1 regression output is constant;
    # confirm the activation is intended when `is_regression` is set.
    output = layers.fully_connected(
        inference_result, output_size, activation_fn=nn_ops.softmax)
    return output

  def inference_graph(self, data, data_spec=None):
    """Returns the op that performs inference on a batch of data.

    Args:
      data: Input batch, forwarded to `_base_inference`.
      data_spec: Forwarded to `_base_inference`.

    Returns:
      `nn_ops.softmax` applied to `_base_inference(..., soft=True)`.
    """
    # NOTE(review): `_base_inference` already ends in a softmax activation, so
    # softmax is applied twice on this path -- confirm this is intended.
    return nn_ops.softmax(
        self._base_inference(
            data, data_spec=data_spec, soft=True))

  # pylint: disable=unused-argument
  def training_inference_graph(self, data, data_spec=None):
    """Returns the inference op used for training.

    Args:
      data: Input batch, forwarded to `_base_inference`.
      data_spec: Forwarded to `_base_inference`.

    Returns:
      The result of `_base_inference(..., soft=False)`, with no additional
      softmax applied by this method.
    """
    return self._base_inference(data, data_spec=data_spec, soft=False)