# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Treats a decision tree as a representation transformation layer.

A decision tree transformer takes features as input and returns the probability
of reaching each leaf as output. The routing throughout the tree is learnable
via backpropagation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.tensor_forest.hybrid.ops import gen_training_ops
from tensorflow.contrib.tensor_forest.hybrid.python import hybrid_layer
from tensorflow.contrib.tensor_forest.hybrid.python.ops import training_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope


def _define_tree_variables(layer, params, name_prefix, num_columns):
  """Creates a layer's routing variables and attaches them to the layer.

  Sets, in this order:
    * `layer.tree_parameters`, shape [params.num_nodes, num_columns], named
      '<name_prefix>tree_parameters_<layer.layer_num>';
    * `layer.tree_thresholds`, shape [params.num_nodes], named
      '<name_prefix>tree_thresholds_<layer.layer_num>'.
  Both are initialized from a truncated normal with mean
  `params.weight_init_mean` and stddev `params.weight_init_std`, and are
  created under `ops.device(layer.device_assigner)`.

  Args:
    layer: The layer receiving the variables; its `layer_num` and
      `device_assigner` attributes are read.
    params: Hyperparameter object providing num_nodes, weight_init_mean and
      weight_init_std.
    name_prefix: String prepended to both variable names, so that each layer
      type gets its own variables (e.g. '' or 'hard_').
    num_columns: Width of the per-node parameter matrix.
  """
  with ops.device(layer.device_assigner):
    layer.tree_parameters = variable_scope.get_variable(
        name='%stree_parameters_%d' % (name_prefix, layer.layer_num),
        shape=[params.num_nodes, num_columns],
        initializer=init_ops.truncated_normal_initializer(
            mean=params.weight_init_mean, stddev=params.weight_init_std))

    layer.tree_thresholds = variable_scope.get_variable(
        name='%stree_thresholds_%d' % (name_prefix, layer.layer_num),
        shape=[params.num_nodes],
        initializer=init_ops.truncated_normal_initializer(
            mean=params.weight_init_mean, stddev=params.weight_init_std))


def _slice_leaf_probabilities(node_probabilities, params):
  """Selects the leaf columns from a per-node probability tensor.

  Keeps every row (size -1 on axis 0) and `params.num_leaves` columns
  starting at column `params.num_nodes - params.num_leaves - 1`.

  Args:
    node_probabilities: 2-D tensor, presumably [batch, num_nodes] as produced
      by the routing ops — TODO confirm against the op definitions.
    params: Hyperparameter object providing num_nodes and num_leaves.

  Returns:
    The sliced tensor with `params.num_leaves` columns.
  """
  # NOTE(review): if the leaves occupy the last num_leaves node indices, the
  # start column would be num_nodes - num_leaves and this "- 1" would be an
  # off-by-one. It is kept because every layer in this module uses the same
  # offset; confirm against the routing ops' node layout before changing it.
  return array_ops.slice(
      node_probabilities,
      [0, params.num_nodes - params.num_leaves - 1],
      [-1, params.num_leaves])


class DecisionsToDataLayer(hybrid_layer.HybridLayer):
  """A layer that treats soft decisions as data."""

  def _define_vars(self, params, **kwargs):
    """Creates dense routing variables (one weight per input feature)."""
    _define_tree_variables(self, params, '', params.num_features)

  def __init__(self, params, layer_num, device_assigner,
               *args, **kwargs):
    """Forwards construction to HybridLayer and loads the training ops."""
    super(DecisionsToDataLayer, self).__init__(
        params, layer_num, device_assigner, *args, **kwargs)

    self._training_ops = training_ops.Load()

  def inference_graph(self, data):
    """Routes `data` softly through the tree and returns the leaf columns.

    Args:
      data: Input feature tensor (assumed [batch, num_features] — TODO
        confirm).

    Returns:
      The leaf slice of the routing_function output.
    """
    with ops.device(self.device_assigner):
      routing_probabilities = gen_training_ops.routing_function(
          data,
          self.tree_parameters,
          self.tree_thresholds,
          max_nodes=self.params.num_nodes)

      return _slice_leaf_probabilities(routing_probabilities, self.params)


class KFeatureDecisionsToDataLayer(hybrid_layer.HybridLayer):
  """A layer that treats soft decisions made on single features as data."""

  def _define_vars(self, params, **kwargs):
    """Creates routing variables with num_features_per_node weights per node."""
    _define_tree_variables(self, params, '', params.num_features_per_node)

  def __init__(self, params, layer_num, device_assigner,
               *args, **kwargs):
    """Forwards construction to HybridLayer and loads the training ops."""
    super(KFeatureDecisionsToDataLayer, self).__init__(
        params, layer_num, device_assigner, *args, **kwargs)

    self._training_ops = training_ops.Load()

  def inference_graph(self, data):
    """Routes `data` with k-feature decisions and returns the leaf columns.

    Args:
      data: Input feature tensor (assumed [batch, num_features] — TODO
        confirm).

    Returns:
      The leaf slice of the k_feature_routing_function output.
    """
    with ops.device(self.device_assigner):
      routing_probabilities = gen_training_ops.k_feature_routing_function(
          data,
          self.tree_parameters,
          self.tree_thresholds,
          max_nodes=self.params.num_nodes,
          num_features_per_node=self.params.num_features_per_node,
          # NOTE(review): hard-coded 0 rather than self.layer_num; confirm
          # this is intended (it presumably feeds the op's feature sampling).
          layer_num=0,
          random_seed=self.params.base_random_seed)

      return _slice_leaf_probabilities(routing_probabilities, self.params)


class HardDecisionsToDataLayer(DecisionsToDataLayer):
  """A layer that learns a soft decision tree but treats it as hard at test."""

  def _define_vars(self, params, **kwargs):
    """Creates dense routing variables under 'hard_'-prefixed names."""
    _define_tree_variables(self, params, 'hard_', params.num_features)

  def soft_inference_graph(self, data):
    """Returns the soft-routing leaf output of DecisionsToDataLayer."""
    return super(HardDecisionsToDataLayer, self).inference_graph(data)

  def inference_graph(self, data):
    """Routes `data` along a single hard path and returns the leaf columns.

    The (path_probability, path) pair from hard_routing_function is expanded
    with unpack_path before the leaf columns are sliced out.
    """
    with ops.device(self.device_assigner):
      path_probability, path = gen_training_ops.hard_routing_function(
          data,
          self.tree_parameters,
          self.tree_thresholds,
          max_nodes=self.params.num_nodes,
          tree_depth=self.params.hybrid_tree_depth)

      return _slice_leaf_probabilities(
          gen_training_ops.unpack_path(path, path_probability), self.params)


class StochasticHardDecisionsToDataLayer(HardDecisionsToDataLayer):
  """A layer that learns a soft decision tree by sampling paths.

  soft_inference_graph samples a path with stochastic_hard_routing_function;
  inference_graph is inherited unchanged from HardDecisionsToDataLayer.
  """

  def _define_vars(self, params, **kwargs):
    """Creates dense routing variables under 'stochastic_hard_' names."""
    _define_tree_variables(
        self, params, 'stochastic_hard_', params.num_features)

  def soft_inference_graph(self, data):
    """Samples a routing path for `data` and returns the leaf columns."""
    with ops.device(self.device_assigner):
      path_probability, path = (
          gen_training_ops.stochastic_hard_routing_function(
              data,
              self.tree_parameters,
              self.tree_thresholds,
              tree_depth=self.params.hybrid_tree_depth,
              random_seed=self.params.base_random_seed))

      return _slice_leaf_probabilities(
          gen_training_ops.unpack_path(path, path_probability), self.params)


class StochasticSoftDecisionsToDataLayer(StochasticHardDecisionsToDataLayer):
  """A layer that learns a soft decision tree by sampling paths.

  soft_inference_graph (inherited) samples paths, while inference_graph uses
  fully soft routing.
  """

  def _define_vars(self, params, **kwargs):
    """Creates dense routing variables under 'stochastic_soft_' names."""
    _define_tree_variables(
        self, params, 'stochastic_soft_', params.num_features)

  def inference_graph(self, data):
    """Routes `data` softly through the tree and returns the leaf columns."""
    with ops.device(self.device_assigner):
      routes = gen_training_ops.routing_function(
          data,
          self.tree_parameters,
          self.tree_thresholds,
          max_nodes=self.params.num_nodes)

      return _slice_leaf_probabilities(routes, self.params)