Home | History | Annotate | Download | only in layers
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Treats a decision tree as a representation transformation layer.
     16 
     17 A decision tree transformer takes features as input and returns the probability
     18 of reaching each leaf as output.  The routing throughout the tree is learnable
     19 via backpropagation.
     20 """
     21 from __future__ import absolute_import
     22 from __future__ import division
     23 from __future__ import print_function
     24 
     25 from tensorflow.contrib.tensor_forest.hybrid.ops import gen_training_ops
     26 from tensorflow.contrib.tensor_forest.hybrid.python import hybrid_layer
     27 from tensorflow.contrib.tensor_forest.hybrid.python.ops import training_ops
     28 from tensorflow.python.framework import ops
     29 from tensorflow.python.ops import array_ops
     30 from tensorflow.python.ops import init_ops
     31 from tensorflow.python.ops import variable_scope
     32 
     33 
     34 class DecisionsToDataLayer(hybrid_layer.HybridLayer):
     35   """A layer that treats soft decisions as data."""
     36 
     37   def _define_vars(self, params, **kwargs):
     38     with ops.device(self.device_assigner):
     39 
     40       self.tree_parameters = variable_scope.get_variable(
     41           name='tree_parameters_%d' % self.layer_num,
     42           shape=[params.num_nodes, params.num_features],
     43           initializer=init_ops.truncated_normal_initializer(
     44               mean=params.weight_init_mean, stddev=params.weight_init_std))
     45 
     46       self.tree_thresholds = variable_scope.get_variable(
     47           name='tree_thresholds_%d' % self.layer_num,
     48           shape=[params.num_nodes],
     49           initializer=init_ops.truncated_normal_initializer(
     50               mean=params.weight_init_mean, stddev=params.weight_init_std))
     51 
     52   def __init__(self, params, layer_num, device_assigner,
     53                *args, **kwargs):
     54     super(DecisionsToDataLayer, self).__init__(
     55         params, layer_num, device_assigner, *args, **kwargs)
     56 
     57     self._training_ops = training_ops.Load()
     58 
     59   def inference_graph(self, data):
     60     with ops.device(self.device_assigner):
     61       routing_probabilities = gen_training_ops.routing_function(
     62           data,
     63           self.tree_parameters,
     64           self.tree_thresholds,
     65           max_nodes=self.params.num_nodes)
     66 
     67       output = array_ops.slice(
     68           routing_probabilities,
     69           [0, self.params.num_nodes - self.params.num_leaves - 1],
     70           [-1, self.params.num_leaves])
     71 
     72       return output
     73 
     74 
     75 class KFeatureDecisionsToDataLayer(hybrid_layer.HybridLayer):
     76   """A layer that treats soft decisions made on single features as data."""
     77 
     78   def _define_vars(self, params, **kwargs):
     79     with ops.device(self.device_assigner):
     80 
     81       self.tree_parameters = variable_scope.get_variable(
     82           name='tree_parameters_%d' % self.layer_num,
     83           shape=[params.num_nodes, params.num_features_per_node],
     84           initializer=init_ops.truncated_normal_initializer(
     85               mean=params.weight_init_mean, stddev=params.weight_init_std))
     86 
     87       self.tree_thresholds = variable_scope.get_variable(
     88           name='tree_thresholds_%d' % self.layer_num,
     89           shape=[params.num_nodes],
     90           initializer=init_ops.truncated_normal_initializer(
     91               mean=params.weight_init_mean, stddev=params.weight_init_std))
     92 
     93   def __init__(self, params, layer_num, device_assigner,
     94                *args, **kwargs):
     95     super(KFeatureDecisionsToDataLayer, self).__init__(
     96         params, layer_num, device_assigner, *args, **kwargs)
     97 
     98     self._training_ops = training_ops.Load()
     99 
    100   # pylint: disable=unused-argument
    101   def inference_graph(self, data):
    102     with ops.device(self.device_assigner):
    103       routing_probabilities = gen_training_ops.k_feature_routing_function(
    104           data,
    105           self.tree_parameters,
    106           self.tree_thresholds,
    107           max_nodes=self.params.num_nodes,
    108           num_features_per_node=self.params.num_features_per_node,
    109           layer_num=0,
    110           random_seed=self.params.base_random_seed)
    111 
    112       output = array_ops.slice(
    113           routing_probabilities,
    114           [0, self.params.num_nodes - self.params.num_leaves - 1],
    115           [-1, self.params.num_leaves])
    116 
    117       return output
    118 
    119 
    120 class HardDecisionsToDataLayer(DecisionsToDataLayer):
    121   """A layer that learns a soft decision tree but treats it as hard at test."""
    122 
    123   def _define_vars(self, params, **kwargs):
    124     with ops.device(self.device_assigner):
    125 
    126       self.tree_parameters = variable_scope.get_variable(
    127           name='hard_tree_parameters_%d' % self.layer_num,
    128           shape=[params.num_nodes, params.num_features],
    129           initializer=variable_scope.truncated_normal_initializer(
    130               mean=params.weight_init_mean, stddev=params.weight_init_std))
    131 
    132       self.tree_thresholds = variable_scope.get_variable(
    133           name='hard_tree_thresholds_%d' % self.layer_num,
    134           shape=[params.num_nodes],
    135           initializer=variable_scope.truncated_normal_initializer(
    136               mean=params.weight_init_mean, stddev=params.weight_init_std))
    137 
    138   def soft_inference_graph(self, data):
    139     return super(HardDecisionsToDataLayer, self).inference_graph(data)
    140 
    141   def inference_graph(self, data):
    142     with ops.device(self.device_assigner):
    143       path_probability, path = gen_training_ops.hard_routing_function(
    144           data,
    145           self.tree_parameters,
    146           self.tree_thresholds,
    147           max_nodes=self.params.num_nodes,
    148           tree_depth=self.params.hybrid_tree_depth)
    149 
    150       output = array_ops.slice(
    151           gen_training_ops.unpack_path(path, path_probability),
    152           [0, self.params.num_nodes - self.params.num_leaves - 1],
    153           [-1, self.params.num_leaves])
    154 
    155       return output
    156 
    157 
    158 class StochasticHardDecisionsToDataLayer(HardDecisionsToDataLayer):
    159   """A layer that learns a soft decision tree by sampling paths."""
    160 
    161   def _define_vars(self, params, **kwargs):
    162     with ops.device(self.device_assigner):
    163 
    164       self.tree_parameters = variable_scope.get_variable(
    165           name='stochastic_hard_tree_parameters_%d' % self.layer_num,
    166           shape=[params.num_nodes, params.num_features],
    167           initializer=init_ops.truncated_normal_initializer(
    168               mean=params.weight_init_mean, stddev=params.weight_init_std))
    169 
    170       self.tree_thresholds = variable_scope.get_variable(
    171           name='stochastic_hard_tree_thresholds_%d' % self.layer_num,
    172           shape=[params.num_nodes],
    173           initializer=init_ops.truncated_normal_initializer(
    174               mean=params.weight_init_mean, stddev=params.weight_init_std))
    175 
    176   def soft_inference_graph(self, data):
    177     with ops.device(self.device_assigner):
    178       path_probability, path = (
    179           gen_training_ops.stochastic_hard_routing_function(
    180               data,
    181               self.tree_parameters,
    182               self.tree_thresholds,
    183               tree_depth=self.params.hybrid_tree_depth,
    184               random_seed=self.params.base_random_seed))
    185 
    186       output = array_ops.slice(
    187           gen_training_ops.unpack_path(path, path_probability),
    188           [0, self.params.num_nodes - self.params.num_leaves - 1],
    189           [-1, self.params.num_leaves])
    190 
    191       return output
    192 
    193   def inference_graph(self, data):
    194     with ops.device(self.device_assigner):
    195       path_probability, path = gen_training_ops.hard_routing_function(
    196           data,
    197           self.tree_parameters,
    198           self.tree_thresholds,
    199           max_nodes=self.params.num_nodes,
    200           tree_depth=self.params.hybrid_tree_depth)
    201 
    202       output = array_ops.slice(
    203           gen_training_ops.unpack_path(path, path_probability),
    204           [0, self.params.num_nodes - self.params.num_leaves - 1],
    205           [-1, self.params.num_leaves])
    206 
    207       return output
    208 
    209 
    210 class StochasticSoftDecisionsToDataLayer(StochasticHardDecisionsToDataLayer):
    211   """A layer that learns a soft decision tree by sampling paths."""
    212 
    213   def _define_vars(self, params, **kwargs):
    214     with ops.device(self.device_assigner):
    215 
    216       self.tree_parameters = variable_scope.get_variable(
    217           name='stochastic_soft_tree_parameters_%d' % self.layer_num,
    218           shape=[params.num_nodes, params.num_features],
    219           initializer=init_ops.truncated_normal_initializer(
    220               mean=params.weight_init_mean, stddev=params.weight_init_std))
    221 
    222       self.tree_thresholds = variable_scope.get_variable(
    223           name='stochastic_soft_tree_thresholds_%d' % self.layer_num,
    224           shape=[params.num_nodes],
    225           initializer=init_ops.truncated_normal_initializer(
    226               mean=params.weight_init_mean, stddev=params.weight_init_std))
    227 
    228   def inference_graph(self, data):
    229     with ops.device(self.device_assigner):
    230       routes = gen_training_ops.routing_function(
    231           data,
    232           self.tree_parameters,
    233           self.tree_thresholds,
    234           max_nodes=self.params.num_nodes)
    235 
    236       leaf_routes = array_ops.slice(
    237           routes, [0, self.params.num_nodes - self.params.num_leaves - 1],
    238           [-1, self.params.num_leaves])
    239 
    240       return leaf_routes
    241