# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """A tf.learn implementation of online extremely random forests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib import layers
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib

from tensorflow.contrib.tensor_forest.client import eval_metrics
from tensorflow.contrib.tensor_forest.python import tensor_forest

from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util


KEYS_NAME = 'keys'
LOSS_NAME = 'rf_training_loss'
TREE_PATHS_PREDICTION_KEY = 'tree_paths'
VARIANCE_PREDICTION_KEY = 'prediction_variance'
ALL_SERVING_KEY = 'tensorforest_all'
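# Small constant used below to keep probabilities away from exactly 0 or 1
# when converting them to logits.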
EPSILON = 0.000001


class TensorForestRunOpAtEndHook(session_run_hook.SessionRunHook):

  def __init__(self, op_dict):
    """Takes a dict of {name: op} to run before the session is destroyed."""
    self._ops = op_dict

  def end(self, session):
    for name in sorted(self._ops.keys()):
      logging.info('{0}: {1}'.format(name, session.run(self._ops[name])))


class TensorForestLossHook(session_run_hook.SessionRunHook):
  """Monitor to request stop when loss stops decreasing.
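
  Example (a hedged sketch; `params`, `input_fn_train`, and the threshold
  values below are illustrative assumptions, not part of this module; the
  estimator's own default hook is disabled via early_stopping_rounds=None):

  ```python
  est = TensorForestEstimator(params, early_stopping_rounds=None)
  hook = TensorForestLossHook(
      early_stopping_rounds=100,
      early_stopping_loss_threshold=0.01)
  est.fit(input_fn=input_fn_train, monitors=[hook])
  ```

  Training then stops once the loss has failed to improve on its running
  minimum by at least 1% for more than 100 consecutive steps.
  """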

  def __init__(self,
               early_stopping_rounds,
               early_stopping_loss_threshold=None,
               loss_op=None):
    self.early_stopping_rounds = early_stopping_rounds
    self.early_stopping_loss_threshold = early_stopping_loss_threshold
    self.loss_op = loss_op
    self.min_loss = None
    self.last_step = -1
    # self.steps records the number of steps for which the loss has been
    # non-decreasing
    self.steps = 0

  def before_run(self, run_context):
    loss = (self.loss_op if self.loss_op is not None else
            run_context.session.graph.get_operation_by_name(
                LOSS_NAME).outputs[0])
    return session_run_hook.SessionRunArgs(
        {'global_step': training_util.get_global_step(),
         'current_loss': loss})

  def after_run(self, run_context, run_values):
    current_loss = run_values.results['current_loss']
    current_step = run_values.results['global_step']
    self.steps += 1
    # Guard against the global step going backwards, which might happen
    # if we recover from something.
    if self.last_step == -1 or self.last_step > current_step:
      logging.info('TensorForestLossHook resetting last_step.')
      self.last_step = current_step
      self.steps = 0
      self.min_loss = None
      return

    self.last_step = current_step
    # A missing threshold means any strict improvement in the loss resets
    # the non-improvement counter.
    threshold = self.early_stopping_loss_threshold or 0.0
    if (self.min_loss is None or
        current_loss < self.min_loss * (1.0 - threshold)):
      self.min_loss = current_loss
      self.steps = 0
    if self.steps > self.early_stopping_rounds:
      logging.info('TensorForestLossHook requesting stop.')
      run_context.request_stop()


def get_default_head(params, weights_name, name=None):
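  """Returns a regression or classification head, depending on
  `params.regression`."""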
  if params.regression:
    return head_lib.regression_head(
        weight_column_name=weights_name,
        label_dimension=params.num_outputs,
        enable_centered_bias=False,
        head_name=name)
  else:
    return head_lib.multi_class_head(
        params.num_classes,
        weight_column_name=weights_name,
        enable_centered_bias=False,
        head_name=name)


def get_model_fn(params,
                 graph_builder_class,
                 device_assigner,
                 feature_columns=None,
                 weights_name=None,
                 model_head=None,
                 keys_name=None,
                 early_stopping_rounds=100,
                 early_stopping_loss_threshold=0.001,
                 num_trainers=1,
                 trainer_id=0,
                 report_feature_importances=False,
                 local_eval=False,
                 head_scope=None,
                 include_all_in_serving=False):
  """Return a model function given a way to construct a graph builder.
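
  Example (a hedged sketch; the hyperparameter values are illustrative
  assumptions, not part of this module):

  ```python
  params = tensor_forest.ForestHParams(
      num_classes=2, num_features=40, num_trees=10, max_nodes=1000).fill()
  model_fn = get_model_fn(
      params, tensor_forest.RandomForestGraphs, device_assigner=None)
  ```

  The returned `model_fn(features, labels, mode)` builds the forest's
  inference graph, attaches the training graph and any early-stopping and
  clean-up hooks in TRAIN mode, and returns the head's `ModelFnOps`.
  """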
  if model_head is None:
    model_head = get_default_head(params, weights_name)

  def _model_fn(features, labels, mode):
    """Function that returns predictions, training loss, and training op."""
    if (isinstance(features, ops.Tensor) or
        isinstance(features, sparse_tensor.SparseTensor)):
      features = {'features': features}
    if feature_columns:
      features = features.copy()
      features.update(layers.transform_features(features, feature_columns))

    weights = None
    if weights_name and weights_name in features:
      weights = features.pop(weights_name)

    keys = None
    if keys_name and keys_name in features:
      keys = features.pop(keys_name)

    # If we're doing eval, optionally ignore device_assigner.
    # Also ignore device assigner if we're exporting (mode == INFER)
    dev_assn = device_assigner
    if (mode == model_fn_lib.ModeKeys.INFER or
        (local_eval and mode == model_fn_lib.ModeKeys.EVAL)):
      dev_assn = None

    graph_builder = graph_builder_class(params,
                                        device_assigner=dev_assn)

    logits, tree_paths, regression_variance = graph_builder.inference_graph(
        features)

    summary.scalar('average_tree_size', graph_builder.average_size())
    # For binary classification problems, convert probabilities to logits.
    # Includes hack to get around the fact that a probability might be 0 or 1.
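    # logit = log(p / (1 - p)); EPSILON clamps both the ratio's denominator
    # and the log's argument so that probabilities of exactly 0 or 1 still
    # produce finite logits.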
    if not params.regression and params.num_classes == 2:
      class_1_probs = array_ops.slice(logits, [0, 1], [-1, 1])
      logits = math_ops.log(
          math_ops.maximum(class_1_probs / math_ops.maximum(
              1.0 - class_1_probs, EPSILON), EPSILON))

    # labels might be None if we're doing prediction (which brings up the
    # question of why we force everything to adhere to a single model_fn).
    training_graph = None
    training_hooks = []
    if labels is not None and mode == model_fn_lib.ModeKeys.TRAIN:
      with ops.control_dependencies([logits.op]):
        training_graph = control_flow_ops.group(
            graph_builder.training_graph(
                features, labels, input_weights=weights,
                num_trainers=num_trainers,
                trainer_id=trainer_id),
            state_ops.assign_add(training_util.get_global_step(), 1))

    # Put weights back in
    if weights is not None:
      features[weights_name] = weights

    # Unlike many other models, TensorForest's training graph isn't derived
    # directly from the loss, so the train function just returns it.
    def _train_fn(unused_loss):
      return training_graph

    model_ops = model_head.create_model_fn_ops(
        features=features,
        labels=labels,
        mode=mode,
        train_op_fn=_train_fn,
        logits=logits,
        scope=head_scope)

    # Ops are run in lexicographic order of their keys. Run the resource
    # clean-up op last.
    all_handles = graph_builder.get_all_resource_handles()
    ops_at_end = {
        '9: clean up resources': control_flow_ops.group(
            *[resource_variable_ops.destroy_resource_op(handle)
              for handle in all_handles])}

    if report_feature_importances:
      ops_at_end['1: feature_importances'] = (
          graph_builder.feature_importances())

    training_hooks.append(TensorForestRunOpAtEndHook(ops_at_end))

    if early_stopping_rounds:
      training_hooks.append(
          TensorForestLossHook(
              early_stopping_rounds,
              early_stopping_loss_threshold=early_stopping_loss_threshold,
              loss_op=model_ops.loss))

    model_ops.training_hooks.extend(training_hooks)

    if keys is not None:
      model_ops.predictions[keys_name] = keys

    if params.inference_tree_paths:
      model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths

    model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
    if include_all_in_serving:
      # In order to serve the variance we need to add the prediction dict
      # to the output_alternatives dict.
      if not model_ops.output_alternatives:
        model_ops.output_alternatives = {}
      model_ops.output_alternatives[ALL_SERVING_KEY] = (
          constants.ProblemType.UNSPECIFIED, model_ops.predictions)
    return model_ops

  return _model_fn


class TensorForestEstimator(estimator.Estimator):
  """An estimator that can train and evaluate a random forest.

  Example:

  ```python
  params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
      num_classes=2, num_features=40, num_trees=10, max_nodes=1000)

  # Estimator using the default graph builder.
  estimator = TensorForestEstimator(params, model_dir=model_dir)

  # Or estimator using TrainingLossForest as the graph builder.
  estimator = TensorForestEstimator(
      params, graph_builder_class=tensor_forest.TrainingLossForest,
      model_dir=model_dir)

  # Input builders
  def input_fn_train():  # returns x, y
    ...
  def input_fn_eval():  # returns x, y
    ...
  estimator.fit(input_fn=input_fn_train)
  estimator.evaluate(input_fn=input_fn_eval)

  # Predict returns an iterable of dicts.
  results = list(estimator.predict(x=x))
  prob0 = results[0][eval_metrics.INFERENCE_PROB_NAME]
  prediction0 = results[0][eval_metrics.INFERENCE_PRED_NAME]
  ```
  """

  def __init__(self,
               params,
               device_assigner=None,
               model_dir=None,
               feature_columns=None,
               graph_builder_class=tensor_forest.RandomForestGraphs,
               config=None,
               weight_column=None,
               keys_column=None,
               feature_engineering_fn=None,
               early_stopping_rounds=100,
               early_stopping_loss_threshold=0.001,
               num_trainers=1,
               trainer_id=0,
               report_feature_importances=False,
               local_eval=False,
               version=None,
               head=None,
               include_all_in_serving=False):
    """Initializes a TensorForestEstimator instance.

    Args:
      params: ForestHParams object that holds random forest hyperparameters.
        These parameters will be passed into `model_fn`.
      device_assigner: An `object` instance that controls how trees get
        assigned to devices. If `None`, will use
        `tensor_forest.RandomForestDeviceAssigner`.
      model_dir: Directory to save model parameters, graph, etc. To continue
        training a previously saved model, load checkpoints saved to this
        directory into an estimator.
      feature_columns: An iterable containing all the feature columns used by
        the model. All items in the set should be instances of classes derived
        from `_FeatureColumn`.
      graph_builder_class: An `object` instance that defines how TF graphs for
        random forest training and inference are built. By default will use
        `tensor_forest.RandomForestGraphs`. Can be overridden by version
        kwarg.
      config: `RunConfig` object to configure the runtime settings.
      weight_column: A string defining feature column name representing
        weights. Will be multiplied by the loss of the example. Used to
        downweight or boost examples during training.
      keys_column: A string naming one of the features to strip out and
        pass through into the inference/eval results dict.  Useful for
        associating specific examples with their prediction.
      feature_engineering_fn: Feature engineering function. Takes features and
        labels which are the output of `input_fn` and returns features and
        labels which will be fed into the model.
      early_stopping_rounds: Allows training to terminate early if the forest
        is no longer growing. 100 by default. Set to a falsy value to disable
        the default training hook.
      early_stopping_loss_threshold: Fraction by which the loss must improve
        within early_stopping_rounds steps; otherwise training will
        terminate.
      num_trainers: Number of training jobs, which will partition trees
        among them.
      trainer_id: Which trainer this instance is.
      report_feature_importances: If True, print out feature importances
        during evaluation.
      local_eval: If True, don't use a device assigner for eval. This is to
        support some common setups where eval is done on a single machine, even
        though training might be distributed.
      version: Unused.
      head: A heads_lib.Head object that calculates losses and such. If None,
        one will be automatically created based on params.
      include_all_in_serving: If True, prepare the complete prediction dict,
        including the variance, to be exported for serving with the Servo
        lib. This also requires calling export_savedmodel with
        default_output_alternative_key=ALL_SERVING_KEY, i.e.
        estimator.export_savedmodel(export_dir_base=your_export_dir,
          serving_input_fn=your_export_input_fn,
          default_output_alternative_key=ALL_SERVING_KEY)
        If False, use the default behavior: export scores and probabilities
        but no variances. In this case default_output_alternative_key should
        be None when calling export_savedmodel().
        Note that for backward compatibility we cannot always set
        include_all_in_serving to True, because then calling
        export_savedmodel() without
        default_output_alternative_key=ALL_SERVING_KEY (the legacy behavior)
        would make saved_model_export_utils.get_output_alternatives() raise
        a ValueError.

    Returns:
      A `TensorForestEstimator` instance.
    """
    super(TensorForestEstimator, self).__init__(
        model_fn=get_model_fn(
            params.fill(),
            graph_builder_class,
            device_assigner,
            feature_columns=feature_columns,
            model_head=head,
            weights_name=weight_column,
            keys_name=keys_column,
            early_stopping_rounds=early_stopping_rounds,
            early_stopping_loss_threshold=early_stopping_loss_threshold,
            num_trainers=num_trainers,
            trainer_id=trainer_id,
            report_feature_importances=report_feature_importances,
            local_eval=local_eval,
            include_all_in_serving=include_all_in_serving,
        ),
        model_dir=model_dir,
        config=config,
        feature_engineering_fn=feature_engineering_fn)


def get_combined_model_fn(model_fns):
  """Get a combined model function given a list of other model fns.

  The returned model function calls the individual model functions and
  combines their results as follows:

  training ops: tf.group them.
  loss: average them.
  predictions: concat probabilities such that predictions[*][0:C1] are the
    probabilities for output 1 (where C1 is the number of classes in output
    1), predictions[*][C1:C1+C2] are the probabilities for output 2 (where
    C2 is the number of classes in output 2), etc.  Also stack predictions
    such that predictions[i][j] is the class prediction for example i and
    output j.

  This assumes that labels are 2-dimensional, with labels[i][j] being the
  label for example i and output j, where forest j is trained using only
  output j.

  Args:
    model_fns: A list of model functions obtained from get_model_fn.

  Returns:
    A combined model function that can be passed to an `Estimator`.
  """
  def _model_fn(features, labels, mode):
    """Function that returns predictions, training loss, and training op."""
    model_fn_ops = []
    for i in range(len(model_fns)):
      with variable_scope.variable_scope('label_{0}'.format(i)):
        sliced_labels = array_ops.slice(labels, [0, i], [-1, 1])
        model_fn_ops.append(
            model_fns[i](features, sliced_labels, mode))
    training_hooks = []
    for mops in model_fn_ops:
      training_hooks += mops.training_hooks
    predictions = {}
    if (mode == model_fn_lib.ModeKeys.EVAL or
        mode == model_fn_lib.ModeKeys.INFER):
      # Flatten the probabilities into one dimension.
      predictions[eval_metrics.INFERENCE_PROB_NAME] = array_ops.concat(
          [mops.predictions[eval_metrics.INFERENCE_PROB_NAME]
           for mops in model_fn_ops], axis=1)
      predictions[eval_metrics.INFERENCE_PRED_NAME] = array_ops.stack(
          [mops.predictions[eval_metrics.INFERENCE_PRED_NAME]
           for mops in model_fn_ops], axis=1)
    loss = None
    if (mode == model_fn_lib.ModeKeys.EVAL or
        mode == model_fn_lib.ModeKeys.TRAIN):
      loss = math_ops.reduce_sum(
          array_ops.stack(
              [mops.loss for mops in model_fn_ops])) / len(model_fn_ops)

    train_op = None
    if mode == model_fn_lib.ModeKeys.TRAIN:
      train_op = control_flow_ops.group(
          *[mops.train_op for mops in model_fn_ops])
    return model_fn_lib.ModelFnOps(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        training_hooks=training_hooks,
        scaffold=None,
        output_alternatives=None)

  return _model_fn


class MultiForestMultiHeadEstimator(estimator.Estimator):
  """An estimator that can train forests for multi-headed problems.

  This class essentially trains separate forests (each with their own
  ForestHParams) for each output.

  For multi-headed regression, a single-headed TensorForestEstimator can
  be used to train a single model that predicts all outputs.  This class can
  be used to train separate forests for each output.
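
  Example (a hedged sketch; the hyperparameter values and `model_dir` below
  are illustrative assumptions, not part of this module):

  ```python
  params_list = [
      tensor_forest.ForestHParams(
          num_classes=2, num_features=40, num_trees=10, max_nodes=1000),
      tensor_forest.ForestHParams(
          num_classes=3, num_features=40, num_trees=10, max_nodes=1000),
  ]
  estimator = MultiForestMultiHeadEstimator(params_list, model_dir=model_dir)
  ```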
  """

  def __init__(self,
               params_list,
               device_assigner=None,
               model_dir=None,
               feature_columns=None,
               graph_builder_class=tensor_forest.RandomForestGraphs,
               config=None,
               weight_column=None,
               keys_column=None,
               feature_engineering_fn=None,
               early_stopping_rounds=100,
               num_trainers=1,
               trainer_id=0,
               report_feature_importances=False,
               local_eval=False):
    """See TensorForestEstimator.__init__."""
    model_fns = []
    for i in range(len(params_list)):
      params = params_list[i].fill()
      model_fns.append(
          get_model_fn(
              params,
              graph_builder_class,
              device_assigner,
              model_head=get_default_head(
                  params, weight_column, name='head{0}'.format(i)),
              weights_name=weight_column,
              keys_name=keys_column,
              early_stopping_rounds=early_stopping_rounds,
              num_trainers=num_trainers,
              trainer_id=trainer_id,
              report_feature_importances=report_feature_importances,
              local_eval=local_eval,
              head_scope='output{0}'.format(i)))

    super(MultiForestMultiHeadEstimator, self).__init__(
        model_fn=get_combined_model_fn(model_fns),
        model_dir=model_dir,
        config=config,
        feature_engineering_fn=feature_engineering_fn)