# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""GTFlow Model definitions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks
from tensorflow.contrib.boosted_trees.python.ops import model_ops
from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch
from tensorflow.python.framework import ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import training_util


def model_builder(features, labels, mode, params, config):
  """Multi-machine batch gradient descent tree model.

  Args:
    features: `Tensor` or `dict` of `Tensor` objects.
    labels: Labels used to train on.
    mode: Mode we are in (TRAIN/EVAL/INFER).
    params: A dict of hyperparameters.
      The following hyperparameters are expected:
      * head: A `Head` instance.
      * learner_config: A config for the learner.
      * feature_columns: An iterable containing all the feature columns used by
          the model.
      * examples_per_layer: Number of examples to accumulate before growing a
          layer. It can also be a function that computes the number of examples
          based on the depth of the layer that's being built.
      * weight_column_name: The name of the weight column.
      * num_trees: Number of trees to grow. Training stops after this many
          trees (one more if `center_bias` is True) have been built.
      * logits_modifier_function: An optional function that modifies the
          logits; called with the logits, features and mode.
      * center_bias: Whether a separate tree should be created for first fitting
          the bias.
    config: `RunConfig` of the estimator.

  Returns:
    A `ModelFnOps` object.

  Raises:
    ValueError: if inputs are not valid.
  """
  head = params["head"]
  learner_config = params["learner_config"]
  examples_per_layer = params["examples_per_layer"]
  feature_columns = params["feature_columns"]
  weight_column_name = params["weight_column_name"]
  num_trees = params["num_trees"]
  logits_modifier_function = params["logits_modifier_function"]
  if features is None:
    raise ValueError("At least one feature must be specified.")

  if config is None:
    raise ValueError("Missing estimator RunConfig.")

  center_bias = params["center_bias"]

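  # A bare `Tensor` is wrapped in a dict keyed by its name so the rest of the
  # model function can treat features uniformly.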
  if isinstance(features, ops.Tensor):
    features = {features.name: features}

  # Make a shallow copy of features to ensure downstream usage
  # is unaffected by modifications in the model function.
  training_features = copy.copy(features)
  training_features.pop(weight_column_name, None)
  global_step = training_util.get_global_step()
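  # Keep the ensemble variable on the same device as the global step so that,
  # in a distributed setting, both end up on a parameter server.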
  with ops.device(global_step.device):
    ensemble_handle = model_ops.tree_ensemble_variable(
        stamp_token=0,
        tree_ensemble_config="",  # Initialize an empty ensemble.
        name="ensemble_model")

  # Create GBDT model.
  gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
      is_chief=config.is_chief,
      num_ps_replicas=config.num_ps_replicas,
      ensemble_handle=ensemble_handle,
      center_bias=center_bias,
      examples_per_layer=examples_per_layer,
      learner_config=learner_config,
      feature_columns=feature_columns,
      logits_dimension=head.logits_dimension,
      features=training_features)
  with ops.name_scope("gbdt", "gbdt_optimizer"):
    predictions_dict = gbdt_model.predict(mode)
    logits = predictions_dict["predictions"]
    if logits_modifier_function:
      logits = logits_modifier_function(logits, features, mode)

    def _train_op_fn(loss):
      """Returns the op to optimize the loss."""
      update_op = gbdt_model.train(loss, predictions_dict, labels)
      with ops.control_dependencies(
          [update_op]), (ops.colocate_with(global_step)):
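        # Advance the global step only after the ensemble update has run.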
        update_op = state_ops.assign_add(global_step, 1).op
        return update_op

  model_fn_ops = head.create_model_fn_ops(
      features=features,
      mode=mode,
      labels=labels,
      train_op_fn=_train_op_fn,
      logits=logits)
  if num_trees:
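    # Bias centering consumes one extra tree, so the stopping hook must wait
    # for one more tree to finalize.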
    if center_bias:
      num_trees += 1
    finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor()
    model_fn_ops.training_hooks.append(
        trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
                                      finalized_trees))
  return model_fn_ops
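
# Example wiring (a minimal sketch, not part of this module): the canned
# estimators in this package pass `model_builder` as the `model_fn` of a
# `tf.contrib.learn` `Estimator`, together with a params dict such as the one
# below. The concrete values (`my_head`, `my_learner_config`, and so on) are
# hypothetical placeholders.
#
#   from tensorflow.contrib.learn.python.learn.estimators import estimator
#
#   params = {
#       "head": my_head,
#       "learner_config": my_learner_config,
#       "feature_columns": my_feature_columns,
#       "examples_per_layer": 1000,
#       "weight_column_name": None,
#       "num_trees": 100,
#       "logits_modifier_function": None,
#       "center_bias": True,
#   }
#   est = estimator.Estimator(model_fn=model_builder, params=params)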