Home | History | Annotate | Download | only in python
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Extremely random forest graph builder. go/brain-tree."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import math
     21 import numbers
     22 import random
     23 
     24 from google.protobuf import text_format
     25 
     26 from tensorflow.contrib.decision_trees.proto import generic_tree_model_pb2 as _tree_proto
     27 from tensorflow.contrib.framework.python.ops import variables as framework_variables
     28 from tensorflow.contrib.tensor_forest.proto import tensor_forest_params_pb2 as _params_proto
     29 from tensorflow.contrib.tensor_forest.python.ops import data_ops
     30 from tensorflow.contrib.tensor_forest.python.ops import model_ops
     31 from tensorflow.contrib.tensor_forest.python.ops import stats_ops
     32 
     33 from tensorflow.python.framework import ops
     34 from tensorflow.python.ops import array_ops
     35 from tensorflow.python.ops import control_flow_ops
     36 from tensorflow.python.ops import math_ops
     37 from tensorflow.python.ops import random_ops
     38 from tensorflow.python.ops import variable_scope
     39 from tensorflow.python.ops import variables as tf_variables
     40 from tensorflow.python.platform import tf_logging as logging
     41 
     42 
     43 # Stores tuples of (leaf model type, stats model type)
     44 CLASSIFICATION_LEAF_MODEL_TYPES = {
     45     'all_dense': (_params_proto.MODEL_DENSE_CLASSIFICATION,
     46                   _params_proto.STATS_DENSE_GINI),
     47     'all_sparse': (_params_proto.MODEL_SPARSE_CLASSIFICATION,
     48                    _params_proto.STATS_SPARSE_GINI),
     49     'sparse_then_dense':
     50         (_params_proto.MODEL_SPARSE_OR_DENSE_CLASSIFICATION,
     51          _params_proto.STATS_SPARSE_THEN_DENSE_GINI),
     52 }
     53 REGRESSION_MODEL_TYPE = (
     54     _params_proto.MODEL_REGRESSION,
     55     _params_proto.STATS_LEAST_SQUARES_REGRESSION,
     56     _params_proto.COLLECTION_BASIC)
     57 
     58 FINISH_TYPES = {
     59     'basic': _params_proto.SPLIT_FINISH_BASIC,
     60     'hoeffding': _params_proto.SPLIT_FINISH_DOMINATE_HOEFFDING,
     61     'bootstrap': _params_proto.SPLIT_FINISH_DOMINATE_BOOTSTRAP
     62 }
     63 PRUNING_TYPES = {
     64     'none': _params_proto.SPLIT_PRUNE_NONE,
     65     'half': _params_proto.SPLIT_PRUNE_HALF,
     66     'quarter': _params_proto.SPLIT_PRUNE_QUARTER,
     67     '10_percent': _params_proto.SPLIT_PRUNE_10_PERCENT,
     68     'hoeffding': _params_proto.SPLIT_PRUNE_HOEFFDING,
     69 }
     70 SPLIT_TYPES = {
     71     'less_or_equal': _tree_proto.InequalityTest.LESS_OR_EQUAL,
     72     'less': _tree_proto.InequalityTest.LESS_THAN
     73 }
     74 
     75 
     76 def parse_number_or_string_to_proto(proto, param):
     77   if isinstance(param, numbers.Number):
     78     proto.constant_value = param
     79   else:  # assume it's a string
     80     if param.isdigit():
     81       proto.constant_value = int(param)
     82     else:
     83       text_format.Merge(param, proto)
     84 
     85 
     86 def build_params_proto(params):
     87   """Build a TensorForestParams proto out of the V4ForestHParams object."""
     88   proto = _params_proto.TensorForestParams()
     89   proto.num_trees = params.num_trees
     90   proto.max_nodes = params.max_nodes
     91   proto.is_regression = params.regression
     92   proto.num_outputs = params.num_classes
     93   proto.num_features = params.num_features
     94 
     95   proto.leaf_type = params.leaf_model_type
     96   proto.stats_type = params.stats_model_type
     97   proto.collection_type = _params_proto.COLLECTION_BASIC
     98   proto.pruning_type.type = params.pruning_type
     99   proto.finish_type.type = params.finish_type
    100 
    101   proto.inequality_test_type = params.split_type
    102 
    103   proto.drop_final_class = False
    104   proto.collate_examples = params.collate_examples
    105   proto.checkpoint_stats = params.checkpoint_stats
    106   proto.use_running_stats_method = params.use_running_stats_method
    107   proto.initialize_average_splits = params.initialize_average_splits
    108   proto.inference_tree_paths = params.inference_tree_paths
    109 
    110   parse_number_or_string_to_proto(proto.pruning_type.prune_every_samples,
    111                                   params.prune_every_samples)
    112   parse_number_or_string_to_proto(proto.finish_type.check_every_steps,
    113                                   params.early_finish_check_every_samples)
    114   parse_number_or_string_to_proto(proto.split_after_samples,
    115                                   params.split_after_samples)
    116   parse_number_or_string_to_proto(proto.num_splits_to_consider,
    117                                   params.num_splits_to_consider)
    118 
    119   proto.dominate_fraction.constant_value = params.dominate_fraction
    120 
    121   if params.param_file:
    122     with open(params.param_file) as f:
    123       text_format.Merge(f.read(), proto)
    124 
    125   return proto
    126 
    127 
    128 # A convenience class for holding random forest hyperparameters.
    129 #
    130 # To just get some good default parameters, use:
    131 #   hparams = ForestHParams(num_classes=2, num_features=40).fill()
    132 #
    133 # Note that num_classes can not be inferred and so must always be specified.
    134 # Also, either num_splits_to_consider or num_features should be set.
    135 #
    136 # To override specific values, pass them to the constructor:
    137 #   hparams = ForestHParams(num_classes=5, num_trees=10, num_features=5).fill()
    138 #
    139 # TODO(thomaswc): Inherit from tf.HParams when that is publicly available.
    140 class ForestHParams(object):
    141   """A base class for holding hyperparameters and calculating good defaults."""
    142 
    143   def __init__(
    144       self,
    145       num_trees=100,
    146       max_nodes=10000,
    147       bagging_fraction=1.0,
    148       num_splits_to_consider=0,
    149       feature_bagging_fraction=1.0,
    150       max_fertile_nodes=0,  # deprecated, unused.
    151       split_after_samples=250,
    152       valid_leaf_threshold=1,
    153       dominate_method='bootstrap',
    154       dominate_fraction=0.99,
    155       model_name='all_dense',
    156       split_finish_name='basic',
    157       split_pruning_name='none',
    158       prune_every_samples=0,
    159       early_finish_check_every_samples=0,
    160       collate_examples=False,
    161       checkpoint_stats=False,
    162       use_running_stats_method=False,
    163       initialize_average_splits=False,
    164       inference_tree_paths=False,
    165       param_file=None,
    166       split_name='less_or_equal',
    167       **kwargs):
    168     self.num_trees = num_trees
    169     self.max_nodes = max_nodes
    170     self.bagging_fraction = bagging_fraction
    171     self.feature_bagging_fraction = feature_bagging_fraction
    172     self.num_splits_to_consider = num_splits_to_consider
    173     self.max_fertile_nodes = max_fertile_nodes
    174     self.split_after_samples = split_after_samples
    175     self.valid_leaf_threshold = valid_leaf_threshold
    176     self.dominate_method = dominate_method
    177     self.dominate_fraction = dominate_fraction
    178     self.model_name = model_name
    179     self.split_finish_name = split_finish_name
    180     self.split_pruning_name = split_pruning_name
    181     self.collate_examples = collate_examples
    182     self.checkpoint_stats = checkpoint_stats
    183     self.use_running_stats_method = use_running_stats_method
    184     self.initialize_average_splits = initialize_average_splits
    185     self.inference_tree_paths = inference_tree_paths
    186     self.param_file = param_file
    187     self.split_name = split_name
    188     self.early_finish_check_every_samples = early_finish_check_every_samples
    189     self.prune_every_samples = prune_every_samples
    190 
    191     for name, value in kwargs.items():
    192       setattr(self, name, value)
    193 
    194   def values(self):
    195     return self.__dict__
    196 
    197   def fill(self):
    198     """Intelligently sets any non-specific parameters."""
    199     # Fail fast if num_classes or num_features isn't set.
    200     _ = getattr(self, 'num_classes')
    201     _ = getattr(self, 'num_features')
    202 
    203     self.bagged_num_features = int(self.feature_bagging_fraction *
    204                                    self.num_features)
    205 
    206     self.bagged_features = None
    207     if self.feature_bagging_fraction < 1.0:
    208       self.bagged_features = [random.sample(
    209           range(self.num_features),
    210           self.bagged_num_features) for _ in range(self.num_trees)]
    211 
    212     self.regression = getattr(self, 'regression', False)
    213 
    214     # Num_outputs is the actual number of outputs (a single prediction for
    215     # classification, a N-dimenensional point for regression).
    216     self.num_outputs = self.num_classes if self.regression else 1
    217 
    218     # Add an extra column to classes for storing counts, which is needed for
    219     # regression and avoids having to recompute sums for classification.
    220     self.num_output_columns = self.num_classes + 1
    221 
    222     # Our experiments have found that num_splits_to_consider = num_features
    223     # gives good accuracy.
    224     self.num_splits_to_consider = self.num_splits_to_consider or min(
    225         max(10, math.floor(math.sqrt(self.num_features))), 1000)
    226 
    227     # If base_random_seed is 0, the current time will be used to seed the
    228     # random number generators for each tree.  If non-zero, the i-th tree
    229     # will be seeded with base_random_seed + i.
    230     self.base_random_seed = getattr(self, 'base_random_seed', 0)
    231 
    232     # How to store leaf models.
    233     self.leaf_model_type = (
    234         REGRESSION_MODEL_TYPE[0] if self.regression else
    235         CLASSIFICATION_LEAF_MODEL_TYPES[self.model_name][0])
    236 
    237     # How to store stats objects.
    238     self.stats_model_type = (
    239         REGRESSION_MODEL_TYPE[1] if self.regression else
    240         CLASSIFICATION_LEAF_MODEL_TYPES[self.model_name][1])
    241 
    242     self.finish_type = (
    243         _params_proto.SPLIT_FINISH_BASIC if self.regression else
    244         FINISH_TYPES[self.split_finish_name])
    245 
    246     self.pruning_type = PRUNING_TYPES[self.split_pruning_name]
    247 
    248     if self.pruning_type == _params_proto.SPLIT_PRUNE_NONE:
    249       self.prune_every_samples = 0
    250     else:
    251       if (not self.prune_every_samples and
    252           not (isinstance(numbers.Number) or
    253                self.split_after_samples.isdigit())):
    254         logging.error(
    255             'Must specify prune_every_samples if using a depth-dependent '
    256             'split_after_samples')
    257       # Pruning half-way through split_after_samples seems like a decent
    258       # default, making it easy to select the number being pruned with
    259       # pruning_type while not paying the cost of pruning too often.  Note that
    260       # this only holds if not using a depth-dependent split_after_samples.
    261       self.prune_every_samples = (self.prune_every_samples or
    262                                   int(self.split_after_samples) / 2)
    263 
    264     if self.finish_type == _params_proto.SPLIT_FINISH_BASIC:
    265       self.early_finish_check_every_samples = 0
    266     else:
    267       if (not self.early_finish_check_every_samples and
    268           not (isinstance(numbers.Number) or
    269                self.split_after_samples.isdigit())):
    270         logging.error(
    271             'Must specify prune_every_samples if using a depth-dependent '
    272             'split_after_samples')
    273       # Checking for early finish every quarter through split_after_samples
    274       # seems like a decent default. We don't want to incur the checking cost
    275       # too often, but (at least for hoeffding) it's lower than the cost of
    276       # pruning so we can do it a little more frequently.
    277       self.early_finish_check_every_samples = (
    278           self.early_finish_check_every_samples or
    279           int(self.split_after_samples) / 4)
    280 
    281     self.split_type = SPLIT_TYPES[self.split_name]
    282 
    283     return self
    284 
    285 
    286 def get_epoch_variable():
    287   """Returns the epoch variable, or [0] if not defined."""
    288   # Grab epoch variable defined in
    289   # //third_party/tensorflow/python/training/input.py::limit_epochs
    290   for v in tf_variables.local_variables():
    291     if 'limit_epochs/epoch' in v.op.name:
    292       return array_ops.reshape(v, [1])
    293   # TODO(thomaswc): Access epoch from the data feeder.
    294   return [0]
    295 
    296 
    297 # A simple container to hold the training variables for a single tree.
    298 class TreeTrainingVariables(object):
    299   """Stores tf.Variables for training a single random tree.
    300 
    301   Uses tf.get_variable to get tree-specific names so that this can be used
    302   with a tf.learn-style implementation (one that trains a model, saves it,
    303   then relies on restoring that model to evaluate).
    304   """
    305 
    306   def __init__(self, params, tree_num, training):
    307     if (not hasattr(params, 'params_proto') or
    308         not isinstance(params.params_proto,
    309                        _params_proto.TensorForestParams)):
    310       params.params_proto = build_params_proto(params)
    311 
    312     params.serialized_params_proto = params.params_proto.SerializeToString()
    313     self.stats = None
    314     if training:
    315       # TODO(gilberth): Manually shard this to be able to fit it on
    316       # multiple machines.
    317       self.stats = stats_ops.fertile_stats_variable(
    318           params, '', self.get_tree_name('stats', tree_num))
    319     self.tree = model_ops.tree_variable(
    320         params, '', self.stats, self.get_tree_name('tree', tree_num))
    321 
    322   def get_tree_name(self, name, num):
    323     return '{0}-{1}'.format(name, num)
    324 
    325 
    326 class ForestTrainingVariables(object):
    327   """A container for a forests training data, consisting of multiple trees.
    328 
    329   Instantiates a TreeTrainingVariables object for each tree. We override the
    330   __getitem__ and __setitem__ function so that usage looks like this:
    331 
    332     forest_variables = ForestTrainingVariables(params)
    333 
    334     ... forest_variables.tree ...
    335   """
    336 
    337   def __init__(self, params, device_assigner, training=True,
    338                tree_variables_class=TreeTrainingVariables):
    339     self.variables = []
    340     # Set up some scalar variables to run through the device assigner, then
    341     # we can use those to colocate everything related to a tree.
    342     self.device_dummies = []
    343     with ops.device(device_assigner):
    344       for i in range(params.num_trees):
    345         self.device_dummies.append(variable_scope.get_variable(
    346             name='device_dummy_%d' % i, shape=0))
    347 
    348     for i in range(params.num_trees):
    349       with ops.device(self.device_dummies[i].device):
    350         self.variables.append(tree_variables_class(params, i, training))
    351 
    352   def __setitem__(self, t, val):
    353     self.variables[t] = val
    354 
    355   def __getitem__(self, t):
    356     return self.variables[t]
    357 
    358 
    359 class RandomForestGraphs(object):
    360   """Builds TF graphs for random forest training and inference."""
    361 
    362   def __init__(self,
    363                params,
    364                device_assigner=None,
    365                variables=None,
    366                tree_variables_class=TreeTrainingVariables,
    367                tree_graphs=None,
    368                training=True):
    369     self.params = params
    370     self.device_assigner = (
    371         device_assigner or framework_variables.VariableDeviceChooser())
    372     logging.info('Constructing forest with params = ')
    373     logging.info(self.params.__dict__)
    374     self.variables = variables or ForestTrainingVariables(
    375         self.params, device_assigner=self.device_assigner, training=training,
    376         tree_variables_class=tree_variables_class)
    377     tree_graph_class = tree_graphs or RandomTreeGraphs
    378     self.trees = [
    379         tree_graph_class(self.variables[i], self.params, i)
    380         for i in range(self.params.num_trees)
    381     ]
    382 
    383   def _bag_features(self, tree_num, input_data):
    384     split_data = array_ops.split(
    385         value=input_data, num_or_size_splits=self.params.num_features, axis=1)
    386     return array_ops.concat(
    387         [split_data[ind] for ind in self.params.bagged_features[tree_num]], 1)
    388 
    389   def get_all_resource_handles(self):
    390     return ([self.variables[i].tree for i in range(len(self.trees))] +
    391             [self.variables[i].stats for i in range(len(self.trees))])
    392 
    393   def training_graph(self,
    394                      input_data,
    395                      input_labels,
    396                      num_trainers=1,
    397                      trainer_id=0,
    398                      **tree_kwargs):
    399     """Constructs a TF graph for training a random forest.
    400 
    401     Args:
    402       input_data: A tensor or dict of string->Tensor for input data.
    403       input_labels: A tensor or placeholder for labels associated with
    404         input_data.
    405       num_trainers: Number of parallel trainers to split trees among.
    406       trainer_id: Which trainer this instance is.
    407       **tree_kwargs: Keyword arguments passed to each tree's training_graph.
    408 
    409     Returns:
    410       The last op in the random forest training graph.
    411 
    412     Raises:
    413       NotImplementedError: If trying to use bagging with sparse features.
    414     """
    415     processed_dense_features, processed_sparse_features, data_spec = (
    416         data_ops.ParseDataTensorOrDict(input_data))
    417 
    418     if input_labels is not None:
    419       labels = data_ops.ParseLabelTensorOrDict(input_labels)
    420 
    421     data_spec = data_spec or self.get_default_data_spec(input_data)
    422 
    423     tree_graphs = []
    424     trees_per_trainer = self.params.num_trees / num_trainers
    425     tree_start = int(trainer_id * trees_per_trainer)
    426     tree_end = int((trainer_id + 1) * trees_per_trainer)
    427     for i in range(tree_start, tree_end):
    428       with ops.device(self.variables.device_dummies[i].device):
    429         seed = self.params.base_random_seed
    430         if seed != 0:
    431           seed += i
    432         # If using bagging, randomly select some of the input.
    433         tree_data = processed_dense_features
    434         tree_labels = labels
    435         if self.params.bagging_fraction < 1.0:
    436           # TODO(gilberth): Support bagging for sparse features.
    437           if processed_sparse_features is not None:
    438             raise NotImplementedError(
    439                 'Bagging not supported with sparse features.')
    440           # TODO(thomaswc): This does sampling without replacement.  Consider
    441           # also allowing sampling with replacement as an option.
    442           batch_size = array_ops.strided_slice(
    443               array_ops.shape(processed_dense_features), [0], [1])
    444           r = random_ops.random_uniform(batch_size, seed=seed)
    445           mask = math_ops.less(
    446               r, array_ops.ones_like(r) * self.params.bagging_fraction)
    447           gather_indices = array_ops.squeeze(
    448               array_ops.where(mask), squeeze_dims=[1])
    449           # TODO(thomaswc): Calculate out-of-bag data and labels, and store
    450           # them for use in calculating statistics later.
    451           tree_data = array_ops.gather(processed_dense_features, gather_indices)
    452           tree_labels = array_ops.gather(labels, gather_indices)
    453         if self.params.bagged_features:
    454           if processed_sparse_features is not None:
    455             raise NotImplementedError(
    456                 'Feature bagging not supported with sparse features.')
    457           tree_data = self._bag_features(i, tree_data)
    458 
    459         tree_graphs.append(self.trees[i].training_graph(
    460             tree_data,
    461             tree_labels,
    462             seed,
    463             data_spec=data_spec,
    464             sparse_features=processed_sparse_features,
    465             **tree_kwargs))
    466 
    467     return control_flow_ops.group(*tree_graphs, name='train')
    468 
    469   def inference_graph(self, input_data, **inference_args):
    470     """Constructs a TF graph for evaluating a random forest.
    471 
    472     Args:
    473       input_data: A tensor or dict of string->Tensor for the input data.
    474                   This input_data must generate the same spec as the
    475                   input_data used in training_graph:  the dict must have
    476                   the same keys, for example, and all tensors must have
    477                   the same size in their first dimension.
    478       **inference_args: Keyword arguments to pass through to each tree.
    479 
    480     Returns:
    481       A tuple of (probabilities, tree_paths, variance).
    482 
    483     Raises:
    484       NotImplementedError: If trying to use feature bagging with sparse
    485         features.
    486     """
    487     processed_dense_features, processed_sparse_features, data_spec = (
    488         data_ops.ParseDataTensorOrDict(input_data))
    489 
    490     probabilities = []
    491     paths = []
    492     for i in range(self.params.num_trees):
    493       with ops.device(self.variables.device_dummies[i].device):
    494         tree_data = processed_dense_features
    495         if self.params.bagged_features:
    496           if processed_sparse_features is not None:
    497             raise NotImplementedError(
    498                 'Feature bagging not supported with sparse features.')
    499           tree_data = self._bag_features(i, tree_data)
    500         probs, path = self.trees[i].inference_graph(
    501             tree_data,
    502             data_spec,
    503             sparse_features=processed_sparse_features,
    504             **inference_args)
    505         probabilities.append(probs)
    506         paths.append(path)
    507     with ops.device(self.variables.device_dummies[0].device):
    508       # shape of all_predict should be [batch_size, num_trees, num_outputs]
    509       all_predict = array_ops.stack(probabilities, axis=1)
    510       average_values = math_ops.div(
    511           math_ops.reduce_sum(all_predict, 1),
    512           self.params.num_trees,
    513           name='probabilities')
    514       tree_paths = array_ops.stack(paths, axis=1)
    515 
    516       expected_squares = math_ops.div(
    517           math_ops.reduce_sum(all_predict * all_predict, 1),
    518           self.params.num_trees)
    519       regression_variance = math_ops.maximum(
    520           0., expected_squares - average_values * average_values)
    521       return average_values, tree_paths, regression_variance
    522 
    523   def average_size(self):
    524     """Constructs a TF graph for evaluating the average size of a forest.
    525 
    526     Returns:
    527       The average number of nodes over the trees.
    528     """
    529     sizes = []
    530     for i in range(self.params.num_trees):
    531       with ops.device(self.variables.device_dummies[i].device):
    532         sizes.append(self.trees[i].size())
    533     return math_ops.reduce_mean(math_ops.to_float(array_ops.stack(sizes)))
    534 
    535   # pylint: disable=unused-argument
    536   def training_loss(self, features, labels, name='training_loss'):
    537     return math_ops.negative(self.average_size(), name=name)
    538 
    539   # pylint: disable=unused-argument
    540   def validation_loss(self, features, labels):
    541     return math_ops.negative(self.average_size())
    542 
    543   def average_impurity(self):
    544     """Constructs a TF graph for evaluating the leaf impurity of a forest.
    545 
    546     Returns:
    547       The last op in the graph.
    548     """
    549     impurities = []
    550     for i in range(self.params.num_trees):
    551       with ops.device(self.variables.device_dummies[i].device):
    552         impurities.append(self.trees[i].average_impurity())
    553     return math_ops.reduce_mean(array_ops.stack(impurities))
    554 
    555   def feature_importances(self):
    556     tree_counts = [self.trees[i].feature_usage_counts()
    557                    for i in range(self.params.num_trees)]
    558     total_counts = math_ops.reduce_sum(array_ops.stack(tree_counts, 0), 0)
    559     return total_counts / math_ops.reduce_sum(total_counts)
    560 
    561 
    562 class RandomTreeGraphs(object):
    563   """Builds TF graphs for random tree training and inference."""
    564 
    565   def __init__(self, variables, params, tree_num):
    566     self.variables = variables
    567     self.params = params
    568     self.tree_num = tree_num
    569 
    570   def training_graph(self,
    571                      input_data,
    572                      input_labels,
    573                      random_seed,
    574                      data_spec,
    575                      sparse_features=None,
    576                      input_weights=None):
    577 
    578     """Constructs a TF graph for training a random tree.
    579 
    580     Args:
    581       input_data: A tensor or placeholder for input data.
    582       input_labels: A tensor or placeholder for labels associated with
    583         input_data.
    584       random_seed: The random number generator seed to use for this tree.  0
    585         means use the current time as the seed.
    586       data_spec: A data_ops.TensorForestDataSpec object specifying the
    587         original feature/columns of the data.
    588       sparse_features: A tf.SparseTensor for sparse input data.
    589       input_weights: A float tensor or placeholder holding per-input weights,
    590         or None if all inputs are to be weighted equally.
    591 
    592     Returns:
    593       The last op in the random tree training graph.
    594     """
    595     # TODO(gilberth): Use this.
    596     unused_epoch = math_ops.to_int32(get_epoch_variable())
    597 
    598     if input_weights is None:
    599       input_weights = []
    600 
    601     sparse_indices = []
    602     sparse_values = []
    603     sparse_shape = []
    604     if sparse_features is not None:
    605       sparse_indices = sparse_features.indices
    606       sparse_values = sparse_features.values
    607       sparse_shape = sparse_features.dense_shape
    608 
    609     if input_data is None:
    610       input_data = []
    611 
    612     leaf_ids = model_ops.traverse_tree_v4(
    613         self.variables.tree,
    614         input_data,
    615         sparse_indices,
    616         sparse_values,
    617         sparse_shape,
    618         input_spec=data_spec.SerializeToString(),
    619         params=self.params.serialized_params_proto)
    620 
    621     update_model = model_ops.update_model_v4(
    622         self.variables.tree,
    623         leaf_ids,
    624         input_labels,
    625         input_weights,
    626         params=self.params.serialized_params_proto)
    627 
    628     finished_nodes = stats_ops.process_input_v4(
    629         self.variables.tree,
    630         self.variables.stats,
    631         input_data,
    632         sparse_indices,
    633         sparse_values,
    634         sparse_shape,
    635         input_labels,
    636         input_weights,
    637         leaf_ids,
    638         input_spec=data_spec.SerializeToString(),
    639         random_seed=random_seed,
    640         params=self.params.serialized_params_proto)
    641 
    642     with ops.control_dependencies([update_model]):
    643       return stats_ops.grow_tree_v4(
    644           self.variables.tree,
    645           self.variables.stats,
    646           finished_nodes,
    647           params=self.params.serialized_params_proto)
    648 
    649   def inference_graph(self, input_data, data_spec, sparse_features=None):
    650     """Constructs a TF graph for evaluating a random tree.
    651 
    652     Args:
    653       input_data: A tensor or placeholder for input data.
    654       data_spec: A TensorForestDataSpec proto specifying the original
    655         input columns.
    656       sparse_features: A tf.SparseTensor for sparse input data.
    657 
    658     Returns:
    659       A tuple of (probabilities, tree_paths).
    660     """
    661     sparse_indices = []
    662     sparse_values = []
    663     sparse_shape = []
    664     if sparse_features is not None:
    665       sparse_indices = sparse_features.indices
    666       sparse_values = sparse_features.values
    667       sparse_shape = sparse_features.dense_shape
    668     if input_data is None:
    669       input_data = []
    670 
    671     return model_ops.tree_predictions_v4(
    672         self.variables.tree,
    673         input_data,
    674         sparse_indices,
    675         sparse_values,
    676         sparse_shape,
    677         input_spec=data_spec.SerializeToString(),
    678         params=self.params.serialized_params_proto)
    679 
    680   def size(self):
    681     """Constructs a TF graph for evaluating the current number of nodes.
    682 
    683     Returns:
    684       The current number of nodes in the tree.
    685     """
    686     return model_ops.tree_size(self.variables.tree)
    687 
    688   def feature_usage_counts(self):
    689     return model_ops.feature_usage_counts(
    690         self.variables.tree, params=self.params.serialized_params_proto)
    691