# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A tf.learn implementation of online extremely random forests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib import layers
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib

from tensorflow.contrib.tensor_forest.client import eval_metrics
from tensorflow.contrib.tensor_forest.python import tensor_forest

from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util


KEYS_NAME = 'keys'
LOSS_NAME = 'rf_training_loss'
TREE_PATHS_PREDICTION_KEY = 'tree_paths'
VARIANCE_PREDICTION_KEY = 'prediction_variance'
ALL_SERVING_KEY = 'tensorforest_all'
EPSILON = 0.000001


class TensorForestRunOpAtEndHook(session_run_hook.SessionRunHook):

  def __init__(self, op_dict):
    """op_dict is a dict of {name: op} to run before the session is destroyed."""
    self._ops = op_dict

  def end(self, session):
    for name in sorted(self._ops.keys()):
      logging.info('{0}: {1}'.format(name, session.run(self._ops[name])))


class TensorForestLossHook(session_run_hook.SessionRunHook):
  """Monitor to request stop when loss stops decreasing."""

  def __init__(self,
               early_stopping_rounds,
               early_stopping_loss_threshold=None,
               loss_op=None):
    self.early_stopping_rounds = early_stopping_rounds
    # Treat a threshold of None as 0.0, so any decrease in loss counts as an
    # improvement (this also avoids a TypeError in after_run below).
    self.early_stopping_loss_threshold = early_stopping_loss_threshold or 0.0
    self.loss_op = loss_op
    self.min_loss = None
    self.last_step = -1
    # self.steps records the number of steps for which the loss has been
    # non-decreasing.
    self.steps = 0

  def before_run(self, run_context):
    loss = (self.loss_op if self.loss_op is not None else
            run_context.session.graph.get_operation_by_name(
                LOSS_NAME).outputs[0])
    return session_run_hook.SessionRunArgs(
        {'global_step': training_util.get_global_step(),
         'current_loss': loss})

  def after_run(self, run_context, run_values):
    current_loss = run_values.results['current_loss']
    current_step = run_values.results['global_step']
    self.steps += 1
    # Guard against the global step going backwards, which might happen
    # if we restore from an earlier checkpoint.
    if self.last_step == -1 or self.last_step > current_step:
      logging.info('TensorForestLossHook resetting last_step.')
      self.last_step = current_step
      self.steps = 0
      self.min_loss = None
      return

    self.last_step = current_step
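    # Illustrative example of the criterion below: with
    # early_stopping_loss_threshold=0.001 and min_loss=10.0, the loss must
    # fall below 10.0 - 10.0 * 0.001 = 9.99 to count as an improvement and
    # reset `steps`; once `steps` exceeds early_stopping_rounds without such
    # an improvement, training is stopped.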
    if (self.min_loss is None or current_loss <
        (self.min_loss - self.min_loss * self.early_stopping_loss_threshold)):
      self.min_loss = current_loss
      self.steps = 0
    if self.steps > self.early_stopping_rounds:
      logging.info('TensorForestLossHook requesting stop.')
      run_context.request_stop()


def get_default_head(params, weights_name, name=None):
  if params.regression:
    return head_lib.regression_head(
        weight_column_name=weights_name,
        label_dimension=params.num_outputs,
        enable_centered_bias=False,
        head_name=name)
  else:
    return head_lib.multi_class_head(
        params.num_classes,
        weight_column_name=weights_name,
        enable_centered_bias=False,
        head_name=name)


def get_model_fn(params,
                 graph_builder_class,
                 device_assigner,
                 feature_columns=None,
                 weights_name=None,
                 model_head=None,
                 keys_name=None,
                 early_stopping_rounds=100,
                 early_stopping_loss_threshold=0.001,
                 num_trainers=1,
                 trainer_id=0,
                 report_feature_importances=False,
                 local_eval=False,
                 head_scope=None,
                 include_all_in_serving=False):
  """Return a model function given a way to construct a graph builder."""
  if model_head is None:
    model_head = get_default_head(params, weights_name)

  def _model_fn(features, labels, mode):
    """Function that returns predictions, training loss, and training op."""
    if (isinstance(features, ops.Tensor) or
        isinstance(features, sparse_tensor.SparseTensor)):
      features = {'features': features}
    if feature_columns:
      features = features.copy()
      features.update(layers.transform_features(features, feature_columns))

    weights = None
    if weights_name and weights_name in features:
      weights = features.pop(weights_name)

    keys = None
    if keys_name and keys_name in features:
      keys = features.pop(keys_name)

    # If we're doing eval, optionally ignore device_assigner.
    # Also ignore device assigner if we're exporting (mode == INFER).
    dev_assn = device_assigner
    if (mode == model_fn_lib.ModeKeys.INFER or
        (local_eval and mode == model_fn_lib.ModeKeys.EVAL)):
      dev_assn = None

    graph_builder = graph_builder_class(params,
                                        device_assigner=dev_assn)

    logits, tree_paths, regression_variance = graph_builder.inference_graph(
        features)

    summary.scalar('average_tree_size', graph_builder.average_size())
    # For binary classification problems, convert probabilities to logits.
    # Includes hack to get around the fact that a probability might be 0 or 1.
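    # Concretely, this computes logit = log(p / (1 - p)) for the positive
    # class probability p, clamping both the denominator and the ratio at
    # EPSILON so that p == 0.0 or p == 1.0 cannot yield infinite logits.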
    if not params.regression and params.num_classes == 2:
      class_1_probs = array_ops.slice(logits, [0, 1], [-1, 1])
      logits = math_ops.log(
          math_ops.maximum(class_1_probs / math_ops.maximum(
              1.0 - class_1_probs, EPSILON), EPSILON))

    # labels might be None if we're doing prediction (which brings up the
    # question of why we force everything to adhere to a single model_fn).
    training_graph = None
    training_hooks = []
    if labels is not None and mode == model_fn_lib.ModeKeys.TRAIN:
      with ops.control_dependencies([logits.op]):
        training_graph = control_flow_ops.group(
            graph_builder.training_graph(
                features, labels, input_weights=weights,
                num_trainers=num_trainers,
                trainer_id=trainer_id),
            state_ops.assign_add(training_util.get_global_step(), 1))

    # Put weights back in.
    if weights is not None:
      features[weights_name] = weights

    # TensorForest's training graph isn't calculated directly from the loss
    # like many other models.
    def _train_fn(unused_loss):
      return training_graph

    model_ops = model_head.create_model_fn_ops(
        features=features,
        labels=labels,
        mode=mode,
        train_op_fn=_train_fn,
        logits=logits,
        scope=head_scope)

    # Ops are run in lexicographical order of their keys. Run the resource
    # clean-up op last.
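    # The numeric key prefixes ('1: ...', '9: ...') enforce that ordering,
    # since TensorForestRunOpAtEndHook runs its ops sorted by key: feature
    # importances are reported before the tree resources are destroyed.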
    all_handles = graph_builder.get_all_resource_handles()
    ops_at_end = {
        '9: clean up resources': control_flow_ops.group(
            *[resource_variable_ops.destroy_resource_op(handle)
              for handle in all_handles])}

    if report_feature_importances:
      ops_at_end['1: feature_importances'] = (
          graph_builder.feature_importances())

    training_hooks.append(TensorForestRunOpAtEndHook(ops_at_end))

    if early_stopping_rounds:
      training_hooks.append(
          TensorForestLossHook(
              early_stopping_rounds,
              early_stopping_loss_threshold=early_stopping_loss_threshold,
              loss_op=model_ops.loss))

    model_ops.training_hooks.extend(training_hooks)

    if keys is not None:
      model_ops.predictions[keys_name] = keys

    if params.inference_tree_paths:
      model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths

    model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
    if include_all_in_serving:
      # In order to serve the variance we need to add the prediction dict
      # to the output_alternatives dict.
      if not model_ops.output_alternatives:
        model_ops.output_alternatives = {}
      model_ops.output_alternatives[ALL_SERVING_KEY] = (
          constants.ProblemType.UNSPECIFIED, model_ops.predictions)
    return model_ops

  return _model_fn


class TensorForestEstimator(estimator.Estimator):
  """An estimator that can train and evaluate a random forest.

  Example:

  ```python
  params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
      num_classes=2, num_features=40, num_trees=10, max_nodes=1000)

  # Estimator using the default graph builder.
  estimator = TensorForestEstimator(params, model_dir=model_dir)

  # Or estimator using TrainingLossForest as the graph builder.
  estimator = TensorForestEstimator(
      params, graph_builder_class=tensor_forest.TrainingLossForest,
      model_dir=model_dir)

  # Input builders
  def input_fn_train():  # returns x, y
    ...
  def input_fn_eval():  # returns x, y
    ...
  estimator.fit(input_fn=input_fn_train)
  estimator.evaluate(input_fn=input_fn_eval)

  # Predict returns an iterable of dicts.
  results = list(estimator.predict(x=x))
  prob0 = results[0][eval_metrics.INFERENCE_PROB_NAME]
  prediction0 = results[0][eval_metrics.INFERENCE_PRED_NAME]
  ```
  """

  def __init__(self,
               params,
               device_assigner=None,
               model_dir=None,
               feature_columns=None,
               graph_builder_class=tensor_forest.RandomForestGraphs,
               config=None,
               weight_column=None,
               keys_column=None,
               feature_engineering_fn=None,
               early_stopping_rounds=100,
               early_stopping_loss_threshold=0.001,
               num_trainers=1,
               trainer_id=0,
               report_feature_importances=False,
               local_eval=False,
               version=None,
               head=None,
               include_all_in_serving=False):
    """Initializes a TensorForestEstimator instance.

    Args:
      params: ForestHParams object that holds random forest hyperparameters.
        These parameters will be passed into `model_fn`.
      device_assigner: An `object` instance that controls how trees get
        assigned to devices. If `None`, will use
        `tensor_forest.RandomForestDeviceAssigner`.
      model_dir: Directory to save model parameters, graph, etc. To continue
        training a previously saved model, load checkpoints saved to this
        directory into an estimator.
      feature_columns: An iterable containing all the feature columns used by
        the model. All items in the set should be instances of classes derived
        from `_FeatureColumn`.
      graph_builder_class: An `object` instance that defines how TF graphs for
        random forest training and inference are built. By default will use
        `tensor_forest.RandomForestGraphs`. Can be overridden by the version
        kwarg.
      config: `RunConfig` object to configure the runtime settings.
      weight_column: A string defining the feature column name representing
        weights. Will be multiplied by the loss of the example. Used to
        downweight or boost examples during training.
      keys_column: A string naming one of the features to strip out and
        pass through into the inference/eval results dict. Useful for
        associating specific examples with their prediction.
      feature_engineering_fn: Feature engineering function. Takes features and
        labels which are the output of `input_fn` and returns features and
        labels which will be fed into the model.
      early_stopping_rounds: Allows training to terminate early if the forest
        is no longer growing. 100 by default. Set to a falsy value to disable
        the default training hook.
      early_stopping_loss_threshold: Percentage (as a fraction) that the loss
        must improve by within early_stopping_rounds steps, otherwise training
        will terminate.
      num_trainers: Number of training jobs, which will partition trees
        among them.
      trainer_id: Which trainer this instance is.
      report_feature_importances: If True, print out feature importances
        during evaluation.
      local_eval: If True, don't use a device assigner for eval. This is to
        support some common setups where eval is done on a single machine,
        even though training might be distributed.
      version: Unused.
      head: A heads_lib.Head object that calculates losses and such. If None,
        one will be automatically created based on params.
      include_all_in_serving: If True, allow preparation of the complete
        prediction dict (including the variance) to be exported for serving
        with the Servo lib. This also requires calling export_savedmodel with
        default_output_alternative_key=ALL_SERVING_KEY, i.e.
          estimator.export_savedmodel(
              export_dir_base=your_export_dir,
              serving_input_fn=your_export_input_fn,
              default_output_alternative_key=ALL_SERVING_KEY)
        If False, fall back to the default behavior, i.e. export scores and
        probabilities but no variances. In this case,
        default_output_alternative_key should be None when calling
        export_savedmodel().
        Note that for backward compatibility we cannot always set
        include_all_in_serving to True, because then calling
        export_savedmodel() without
        default_output_alternative_key=ALL_SERVING_KEY (the legacy behavior)
        would make saved_model_export_utils.get_output_alternatives() raise a
        ValueError.

    Returns:
      A `TensorForestEstimator` instance.
    """
    super(TensorForestEstimator, self).__init__(
        model_fn=get_model_fn(
            params.fill(),
            graph_builder_class,
            device_assigner,
            feature_columns=feature_columns,
            model_head=head,
            weights_name=weight_column,
            keys_name=keys_column,
            early_stopping_rounds=early_stopping_rounds,
            early_stopping_loss_threshold=early_stopping_loss_threshold,
            num_trainers=num_trainers,
            trainer_id=trainer_id,
            report_feature_importances=report_feature_importances,
            local_eval=local_eval,
            include_all_in_serving=include_all_in_serving,
        ),
        model_dir=model_dir,
        config=config,
        feature_engineering_fn=feature_engineering_fn)


def get_combined_model_fn(model_fns):
  """Get a combined model function given a list of other model fns.

  The model function returned will call the individual model functions and
  combine them appropriately. For:

  training ops: tf.group them.
  loss: average them.
  predictions: concat probabilities such that predictions[*][0:C1] are the
    probabilities for output 1 (where C1 is the number of classes in output
    1), predictions[*][C1:C1+C2] are the probabilities for output 2 (where
    C2 is the number of classes in output 2), etc. Also stack predictions
    such that predictions[i][j] is the class prediction for example i and
    output j.

  This assumes that labels are 2-dimensional, with labels[i][j] being the
  label for example i and output j, where forest j is trained using only
  output j.
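
  For example, with two outputs where output 1 has C1=3 classes and output 2
  has C2=2 classes, the combined probability tensor has shape
  [batch_size, 5] (columns 0-2 for output 1, columns 3-4 for output 2), and
  the stacked prediction tensor has shape [batch_size, 2].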

  Args:
    model_fns: A list of model functions obtained from get_model_fn.

  Returns:
    A model function that runs each of `model_fns` and returns a single
    combined ModelFnOps instance.
  """
  def _model_fn(features, labels, mode):
    """Function that returns predictions, training loss, and training op."""
    model_fn_ops = []
    for i in range(len(model_fns)):
      with variable_scope.variable_scope('label_{0}'.format(i)):
        sliced_labels = array_ops.slice(labels, [0, i], [-1, 1])
        model_fn_ops.append(
            model_fns[i](features, sliced_labels, mode))
    training_hooks = []
    for mops in model_fn_ops:
      training_hooks += mops.training_hooks
    predictions = {}
    if (mode == model_fn_lib.ModeKeys.EVAL or
        mode == model_fn_lib.ModeKeys.INFER):
      # Flatten the probabilities into one dimension.
      predictions[eval_metrics.INFERENCE_PROB_NAME] = array_ops.concat(
          [mops.predictions[eval_metrics.INFERENCE_PROB_NAME]
           for mops in model_fn_ops], axis=1)
      predictions[eval_metrics.INFERENCE_PRED_NAME] = array_ops.stack(
          [mops.predictions[eval_metrics.INFERENCE_PRED_NAME]
           for mops in model_fn_ops], axis=1)
    loss = None
    if (mode == model_fn_lib.ModeKeys.EVAL or
        mode == model_fn_lib.ModeKeys.TRAIN):
      loss = math_ops.reduce_sum(
          array_ops.stack(
              [mops.loss for mops in model_fn_ops])) / len(model_fn_ops)

    train_op = None
    if mode == model_fn_lib.ModeKeys.TRAIN:
      train_op = control_flow_ops.group(
          *[mops.train_op for mops in model_fn_ops])
    return model_fn_lib.ModelFnOps(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        training_hooks=training_hooks,
        scaffold=None,
        output_alternatives=None)

  return _model_fn


class MultiForestMultiHeadEstimator(estimator.Estimator):
  """An estimator that can train a forest for multi-headed problems.

  This class essentially trains separate forests (each with their own
  ForestHParams) for each output.

  For multi-headed regression, a single-headed TensorForestEstimator can
  be used to train a single model that predicts all outputs. This class can
  be used to train separate forests for each output.
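
  Example (a sketch; the two ForestHParams and their settings are
  illustrative, and model_dir/input_fn_train are assumed to be defined as in
  the TensorForestEstimator example above):

  ```python
  params_list = [
      tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
          num_classes=2, num_features=40, num_trees=10, max_nodes=1000),
      tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
          num_classes=3, num_features=40, num_trees=10, max_nodes=1000),
  ]
  # Labels returned by input_fn_train must have shape [batch_size, 2]:
  # column 0 trains the first forest, column 1 the second.
  estimator = MultiForestMultiHeadEstimator(params_list, model_dir=model_dir)
  estimator.fit(input_fn=input_fn_train)
  ```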
  """

  def __init__(self,
               params_list,
               device_assigner=None,
               model_dir=None,
               feature_columns=None,
               graph_builder_class=tensor_forest.RandomForestGraphs,
               config=None,
               weight_column=None,
               keys_column=None,
               feature_engineering_fn=None,
               early_stopping_rounds=100,
               num_trainers=1,
               trainer_id=0,
               report_feature_importances=False,
               local_eval=False):
    """See TensorForestEstimator.__init__."""
    model_fns = []
    for i in range(len(params_list)):
      params = params_list[i].fill()
      model_fns.append(
          get_model_fn(
              params,
              graph_builder_class,
              device_assigner,
              model_head=get_default_head(
                  params, weight_column, name='head{0}'.format(i)),
              weights_name=weight_column,
              keys_name=keys_column,
              early_stopping_rounds=early_stopping_rounds,
              num_trainers=num_trainers,
              trainer_id=trainer_id,
              report_feature_importances=report_feature_importances,
              local_eval=local_eval,
              head_scope='output{0}'.format(i)))

    super(MultiForestMultiHeadEstimator, self).__init__(
        model_fn=get_combined_model_fn(model_fns),
        model_dir=model_dir,
        config=config,
        feature_engineering_fn=feature_engineering_fn)