# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""GTFlow Model definitions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks
from tensorflow.contrib.boosted_trees.python.ops import model_ops
from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch
from tensorflow.python.framework import ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import training_util


def model_builder(features, labels, mode, params, config):
  """Multi-machine batch gradient descent tree model.

  Args:
    features: `Tensor` or `dict` of `Tensor` objects.
    labels: Labels used to train on.
    mode: The mode we are in (TRAIN/EVAL/INFER).
    params: A dict of hyperparameters.
      The following hyperparameters are expected:
      * head: A `Head` instance.
      * learner_config: A config for the learner.
      * feature_columns: An iterable containing all the feature columns used
          by the model.
      * examples_per_layer: Number of examples to accumulate before growing a
          layer. It can also be a function that computes the number of
          examples based on the depth of the layer that's being built.
      * weight_column_name: The name of the weight column.
      * num_trees: Number of trees to grow before stopping; one extra tree is
          allowed when `center_bias` is set.
      * logits_modifier_function: An optional function to modify the logits
          before they are passed to the head; called as
          `logits_modifier_function(logits, features, mode)`.
      * center_bias: Whether a separate tree should be created for first
          fitting the bias.
    config: `RunConfig` of the estimator.

  Returns:
    A `ModelFnOps` object.

  Raises:
    ValueError: if inputs are not valid.
  """
  head = params["head"]
  learner_config = params["learner_config"]
  examples_per_layer = params["examples_per_layer"]
  feature_columns = params["feature_columns"]
  weight_column_name = params["weight_column_name"]
  num_trees = params["num_trees"]
  logits_modifier_function = params["logits_modifier_function"]
  if features is None:
    raise ValueError("At least one feature must be specified.")

  if config is None:
    raise ValueError("Missing estimator RunConfig.")

  center_bias = params["center_bias"]

  if isinstance(features, ops.Tensor):
    features = {features.name: features}

  # Make a shallow copy of features to ensure downstream usage
  # is unaffected by modifications in the model function.
  training_features = copy.copy(features)
  training_features.pop(weight_column_name, None)
  global_step = training_util.get_global_step()
  # Place the ensemble variable on the same device as the global step so that
  # both live together in a distributed setting.
  with ops.device(global_step.device):
    ensemble_handle = model_ops.tree_ensemble_variable(
        stamp_token=0,
        tree_ensemble_config="",  # Initialize an empty ensemble.
        name="ensemble_model")

  # Create GBDT model.
  gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
      is_chief=config.is_chief,
      num_ps_replicas=config.num_ps_replicas,
      ensemble_handle=ensemble_handle,
      center_bias=center_bias,
      examples_per_layer=examples_per_layer,
      learner_config=learner_config,
      feature_columns=feature_columns,
      logits_dimension=head.logits_dimension,
      features=training_features)
  with ops.name_scope("gbdt", "gbdt_optimizer"):
    predictions_dict = gbdt_model.predict(mode)
    logits = predictions_dict["predictions"]
    if logits_modifier_function:
      logits = logits_modifier_function(logits, features, mode)

    def _train_op_fn(loss):
      """Returns the op to optimize the loss."""
      update_op = gbdt_model.train(loss, predictions_dict, labels)
      # Only increment the global step once the ensemble update has run.
      with ops.control_dependencies(
          [update_op]), (ops.colocate_with(global_step)):
        update_op = state_ops.assign_add(global_step, 1).op
        return update_op

    model_fn_ops = head.create_model_fn_ops(
        features=features,
        mode=mode,
        labels=labels,
        train_op_fn=_train_op_fn,
        logits=logits)
    if num_trees:
      if center_bias:
        num_trees += 1
      finalized_trees, attempted_trees = (
          gbdt_model.get_number_of_trees_tensor())
      model_fn_ops.training_hooks.append(
          trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
                                        finalized_trees))
    return model_fn_ops
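

# ------------------------------------------------------------------------------
# Illustrative usage sketch (not part of this module's API): the head, the
# learner-config values, the feature column, and the toy input function below
# are assumptions chosen for demonstration only. In practice, prefer the
# canned estimators in this package, which wrap `model_builder` with this
# same kind of wiring.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
  import tensorflow as tf

  from tensorflow.contrib.boosted_trees.proto import learner_pb2
  from tensorflow.contrib.layers.python.layers import feature_column
  from tensorflow.contrib.learn.python.learn.estimators import estimator
  from tensorflow.contrib.learn.python.learn.estimators import head as head_lib

  # A minimal learner config: binary classification with shallow trees.
  learner_config = learner_pb2.LearnerConfig()
  learner_config.num_classes = 2
  learner_config.constraints.max_tree_depth = 3
  learner_config.learning_rate_tuner.fixed.learning_rate = 0.1

  # `model_builder` reads every one of these keys unconditionally, so all of
  # them must be present (use None for the optional ones).
  params = {
      "head": head_lib.multi_class_head(
          n_classes=2, enable_centered_bias=False),
      "learner_config": learner_config,
      "feature_columns": [feature_column.real_valued_column("x")],
      "examples_per_layer": 4,
      "weight_column_name": None,
      "num_trees": 2,
      "logits_modifier_function": None,
      "center_bias": True,
  }

  def _input_fn():
    # A tiny synthetic dataset, purely for illustration.
    features = {"x": tf.constant([[0.1], [0.9], [0.2], [0.8]])}
    labels = tf.constant([[0.], [1.], [0.], [1.]])
    return features, labels

  # The contrib.learn Estimator passes its RunConfig to `model_builder`
  # through the `config` argument of the model_fn signature.
  est = estimator.Estimator(model_fn=model_builder, params=params)
  est.fit(input_fn=_input_fn, steps=10)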