1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Extremely random forest graph builder. go/brain-tree.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import math 21 import numbers 22 import random 23 24 from google.protobuf import text_format 25 26 from tensorflow.contrib.decision_trees.proto import generic_tree_model_pb2 as _tree_proto 27 from tensorflow.contrib.framework.python.ops import variables as framework_variables 28 from tensorflow.contrib.tensor_forest.proto import tensor_forest_params_pb2 as _params_proto 29 from tensorflow.contrib.tensor_forest.python.ops import data_ops 30 from tensorflow.contrib.tensor_forest.python.ops import model_ops 31 from tensorflow.contrib.tensor_forest.python.ops import stats_ops 32 33 from tensorflow.python.framework import ops 34 from tensorflow.python.ops import array_ops 35 from tensorflow.python.ops import control_flow_ops 36 from tensorflow.python.ops import math_ops 37 from tensorflow.python.ops import random_ops 38 from tensorflow.python.ops import variable_scope 39 from tensorflow.python.ops import variables as tf_variables 40 from tensorflow.python.platform import tf_logging as logging 41 42 43 # Stores tuples of (leaf model type, stats model type) 44 
CLASSIFICATION_LEAF_MODEL_TYPES = {
    'all_dense': (_params_proto.MODEL_DENSE_CLASSIFICATION,
                  _params_proto.STATS_DENSE_GINI),
    'all_sparse': (_params_proto.MODEL_SPARSE_CLASSIFICATION,
                   _params_proto.STATS_SPARSE_GINI),
    'sparse_then_dense':
        (_params_proto.MODEL_SPARSE_OR_DENSE_CLASSIFICATION,
         _params_proto.STATS_SPARSE_THEN_DENSE_GINI),
}

# (leaf model type, stats model type, collection type) used for regression.
# Within this module only the first two entries are read (ForestHParams.fill).
REGRESSION_MODEL_TYPE = (
    _params_proto.MODEL_REGRESSION,
    _params_proto.STATS_LEAST_SQUARES_REGRESSION,
    _params_proto.COLLECTION_BASIC)

# Maps ForestHParams.split_finish_name to a SPLIT_FINISH_* proto enum value.
FINISH_TYPES = {
    'basic': _params_proto.SPLIT_FINISH_BASIC,
    'hoeffding': _params_proto.SPLIT_FINISH_DOMINATE_HOEFFDING,
    'bootstrap': _params_proto.SPLIT_FINISH_DOMINATE_BOOTSTRAP
}
# Maps ForestHParams.split_pruning_name to a SPLIT_PRUNE_* proto enum value.
PRUNING_TYPES = {
    'none': _params_proto.SPLIT_PRUNE_NONE,
    'half': _params_proto.SPLIT_PRUNE_HALF,
    'quarter': _params_proto.SPLIT_PRUNE_QUARTER,
    '10_percent': _params_proto.SPLIT_PRUNE_10_PERCENT,
    'hoeffding': _params_proto.SPLIT_PRUNE_HOEFFDING,
}
# Maps ForestHParams.split_name to the inequality used by split tests.
SPLIT_TYPES = {
    'less_or_equal': _tree_proto.InequalityTest.LESS_OR_EQUAL,
    'less': _tree_proto.InequalityTest.LESS_THAN
}


def _is_constant_param(param):
  """Returns True if `param` would be stored as a constant value.

  Mirrors parse_number_or_string_to_proto: numbers and all-digit strings are
  constants; any other string is treated as a text-format proto (for example
  a depth-dependent value).

  Args:
    param: A number or a string.

  Returns:
    True for a number or an all-digit string, False otherwise.
  """
  return isinstance(param, numbers.Number) or param.isdigit()


def parse_number_or_string_to_proto(proto, param):
  """Fills `proto` from a number or a string.

  Args:
    proto: A proto message with a `constant_value` field that can also be
      populated from text format.
    param: A number (stored as constant_value), an all-digit string (parsed
      with int() and stored as constant_value), or any other string, which is
      merged into `proto` as a text-format proto.
  """
  if isinstance(param, numbers.Number):
    proto.constant_value = param
  elif param.isdigit():
    proto.constant_value = int(param)
  else:  # Assume it's a text-format proto.
    text_format.Merge(param, proto)


def build_params_proto(params):
  """Build a TensorForestParams proto out of the V4ForestHParams object.

  Args:
    params: A filled hyperparameter object (see ForestHParams.fill). It must
      provide the derived fields leaf_model_type, stats_model_type,
      pruning_type, finish_type and split_type, among others.

  Returns:
    A _params_proto.TensorForestParams. If params.param_file is set, the
    text-format proto in that file is merged over the result with
    text_format.Merge, so fields it sets take precedence over derived values.
  """
  proto = _params_proto.TensorForestParams()
  proto.num_trees = params.num_trees
  proto.max_nodes = params.max_nodes
  proto.is_regression = params.regression
  proto.num_outputs = params.num_classes
  proto.num_features = params.num_features

  proto.leaf_type = params.leaf_model_type
  proto.stats_type = params.stats_model_type
  proto.collection_type = _params_proto.COLLECTION_BASIC
  proto.pruning_type.type = params.pruning_type
  proto.finish_type.type = params.finish_type

  proto.inequality_test_type = params.split_type

  proto.drop_final_class = False
  proto.collate_examples = params.collate_examples
  proto.checkpoint_stats = params.checkpoint_stats
  proto.use_running_stats_method = params.use_running_stats_method
  proto.initialize_average_splits = params.initialize_average_splits
  proto.inference_tree_paths = params.inference_tree_paths

  parse_number_or_string_to_proto(proto.pruning_type.prune_every_samples,
                                  params.prune_every_samples)
  parse_number_or_string_to_proto(proto.finish_type.check_every_steps,
                                  params.early_finish_check_every_samples)
  parse_number_or_string_to_proto(proto.split_after_samples,
                                  params.split_after_samples)
  parse_number_or_string_to_proto(proto.num_splits_to_consider,
                                  params.num_splits_to_consider)

  proto.dominate_fraction.constant_value = params.dominate_fraction

  if params.param_file:
    with open(params.param_file) as f:
      text_format.Merge(f.read(), proto)

  return proto


# A convenience class for holding random forest hyperparameters.
#
# To just get some good default parameters, use:
#   hparams = ForestHParams(num_classes=2, num_features=40).fill()
#
# Note that num_classes and num_features can not be inferred and so must
# always be specified; fill() fails fast if either is missing.
#
# To override specific values, pass them to the constructor:
#   hparams = ForestHParams(num_classes=5, num_trees=10, num_features=5).fill()
#
# TODO(thomaswc): Inherit from tf.HParams when that is publicly available.
class ForestHParams(object):
  """A base class for holding hyperparameters and calculating good defaults."""

  def __init__(
      self,
      num_trees=100,
      max_nodes=10000,
      bagging_fraction=1.0,
      num_splits_to_consider=0,
      feature_bagging_fraction=1.0,
      max_fertile_nodes=0,  # deprecated, unused.
      split_after_samples=250,
      valid_leaf_threshold=1,
      dominate_method='bootstrap',
      dominate_fraction=0.99,
      model_name='all_dense',
      split_finish_name='basic',
      split_pruning_name='none',
      prune_every_samples=0,
      early_finish_check_every_samples=0,
      collate_examples=False,
      checkpoint_stats=False,
      use_running_stats_method=False,
      initialize_average_splits=False,
      inference_tree_paths=False,
      param_file=None,
      split_name='less_or_equal',
      **kwargs):
    """Stores the given hyperparameters as attributes.

    Any extra keyword arguments (e.g. num_classes, num_features, regression,
    base_random_seed, which fill() reads) also become attributes.
    """
    self.num_trees = num_trees
    self.max_nodes = max_nodes
    self.bagging_fraction = bagging_fraction
    self.feature_bagging_fraction = feature_bagging_fraction
    self.num_splits_to_consider = num_splits_to_consider
    self.max_fertile_nodes = max_fertile_nodes
    self.split_after_samples = split_after_samples
    self.valid_leaf_threshold = valid_leaf_threshold
    self.dominate_method = dominate_method
    self.dominate_fraction = dominate_fraction
    self.model_name = model_name
    self.split_finish_name = split_finish_name
    self.split_pruning_name = split_pruning_name
    self.collate_examples = collate_examples
    self.checkpoint_stats = checkpoint_stats
    self.use_running_stats_method = use_running_stats_method
    self.initialize_average_splits = initialize_average_splits
    self.inference_tree_paths = inference_tree_paths
    self.param_file = param_file
    self.split_name = split_name
    self.early_finish_check_every_samples = early_finish_check_every_samples
    self.prune_every_samples = prune_every_samples

    for name, value in kwargs.items():
      setattr(self, name, value)

  def values(self):
    """Returns the instance attribute dict itself (not a copy)."""
    return self.__dict__

  def fill(self):
    """Intelligently sets any non-specific parameters.

    Derives the remaining hyperparameters from the ones passed to the
    constructor (or assigned afterwards).

    Returns:
      self, so construction and filling can be chained.

    Raises:
      AttributeError: If num_classes or num_features has not been set.
      KeyError: If a name used for a lookup (model_name and split_finish_name
        when not regressing; split_pruning_name; split_name) is unknown.
      ValueError: If pruning or early finishing is enabled, no explicit check
        interval was given, and split_after_samples is not a constant (so no
        default interval can be derived from it).
    """
    # Fail fast if num_classes or num_features isn't set.
    _ = getattr(self, 'num_classes')
    _ = getattr(self, 'num_features')

    self.bagged_num_features = int(self.feature_bagging_fraction *
                                   self.num_features)

    # With feature bagging, every tree gets its own random subset of
    # bagged_num_features feature indices (sampled without replacement).
    self.bagged_features = None
    if self.feature_bagging_fraction < 1.0:
      self.bagged_features = [random.sample(
          range(self.num_features),
          self.bagged_num_features) for _ in range(self.num_trees)]

    self.regression = getattr(self, 'regression', False)

    # Num_outputs is the actual number of outputs (a single prediction for
    # classification, an N-dimensional point for regression).
    self.num_outputs = self.num_classes if self.regression else 1

    # Add an extra column to classes for storing counts, which is needed for
    # regression and avoids having to recompute sums for classification.
    self.num_output_columns = self.num_classes + 1

    # Default to floor(sqrt(num_features)), clamped to the range [10, 1000].
    self.num_splits_to_consider = self.num_splits_to_consider or min(
        max(10, math.floor(math.sqrt(self.num_features))), 1000)

    # If base_random_seed is 0, the current time will be used to seed the
    # random number generators for each tree. If non-zero, the i-th tree
    # will be seeded with base_random_seed + i.
    self.base_random_seed = getattr(self, 'base_random_seed', 0)

    # How to store leaf models.
    self.leaf_model_type = (
        REGRESSION_MODEL_TYPE[0] if self.regression else
        CLASSIFICATION_LEAF_MODEL_TYPES[self.model_name][0])

    # How to store stats objects.
    self.stats_model_type = (
        REGRESSION_MODEL_TYPE[1] if self.regression else
        CLASSIFICATION_LEAF_MODEL_TYPES[self.model_name][1])

    self.finish_type = (
        _params_proto.SPLIT_FINISH_BASIC if self.regression else
        FINISH_TYPES[self.split_finish_name])

    self.pruning_type = PRUNING_TYPES[self.split_pruning_name]

    if self.pruning_type == _params_proto.SPLIT_PRUNE_NONE:
      self.prune_every_samples = 0
    elif not self.prune_every_samples:
      # A default can only be derived from a constant split_after_samples.
      if not _is_constant_param(self.split_after_samples):
        raise ValueError(
            'Must specify prune_every_samples if using a depth-dependent '
            'split_after_samples')
      # Pruning half-way through split_after_samples seems like a decent
      # default, making it easy to select the number being pruned with
      # pruning_type while not paying the cost of pruning too often.
      self.prune_every_samples = int(self.split_after_samples) / 2

    if self.finish_type == _params_proto.SPLIT_FINISH_BASIC:
      self.early_finish_check_every_samples = 0
    elif not self.early_finish_check_every_samples:
      # A default can only be derived from a constant split_after_samples.
      if not _is_constant_param(self.split_after_samples):
        raise ValueError(
            'Must specify early_finish_check_every_samples if using a '
            'depth-dependent split_after_samples')
      # Checking for early finish every quarter through split_after_samples
      # seems like a decent default. We don't want to incur the checking cost
      # too often, but (at least for hoeffding) it's lower than the cost of
      # pruning so we can do it a little more frequently.
      self.early_finish_check_every_samples = (
          int(self.split_after_samples) / 4)

    self.split_type = SPLIT_TYPES[self.split_name]

    return self


def get_epoch_variable():
  """Returns the epoch variable, or [0] if not defined.

  Looks through the local variables for one whose op name contains
  'limit_epochs/epoch' (the variable created by
  tensorflow/python/training/input.py::limit_epochs) and returns it reshaped
  to shape [1].
  """
  for v in tf_variables.local_variables():
    if 'limit_epochs/epoch' in v.op.name:
      return array_ops.reshape(v, [1])
  # TODO(thomaswc): Access epoch from the data feeder.
  return [0]


# A simple container to hold the training variables for a single tree.
class TreeTrainingVariables(object):
  """Stores tf.Variables for training a single random tree.

  Uses tf.get_variable to get tree-specific names so that this can be used
  with a tf.learn-style implementation (one that trains a model, saves it,
  then relies on restoring that model to evaluate).
  """

  def __init__(self, params, tree_num, training):
    """Creates the tree (and, when training, stats) variables for one tree.

    Note that this caches onto `params`: it sets params.params_proto (if
    missing or not a TensorForestParams) and params.serialized_params_proto.

    Args:
      params: A filled hyperparameter object (see ForestHParams.fill).
      tree_num: Index of the tree, used to build unique variable names.
      training: If true, a stats variable is created; otherwise self.stats
        stays None.
    """
    if (not hasattr(params, 'params_proto') or
        not isinstance(params.params_proto,
                       _params_proto.TensorForestParams)):
      params.params_proto = build_params_proto(params)

    params.serialized_params_proto = params.params_proto.SerializeToString()
    self.stats = None
    if training:
      # TODO(gilberth): Manually shard this to be able to fit it on
      # multiple machines.
      self.stats = stats_ops.fertile_stats_variable(
          params, '', self.get_tree_name('stats', tree_num))
    self.tree = model_ops.tree_variable(
        params, '', self.stats, self.get_tree_name('tree', tree_num))

  def get_tree_name(self, name, num):
    """Returns the per-tree variable name '<name>-<num>'."""
    return '{0}-{1}'.format(name, num)


class ForestTrainingVariables(object):
  """A container for a forest's training data, consisting of multiple trees.

  Instantiates a tree_variables_class object (TreeTrainingVariables by
  default) for each tree. We override the __getitem__ and __setitem__
  functions so that usage looks like this:

    forest_variables = ForestTrainingVariables(params, device_assigner)

    ... forest_variables[tree_num].tree ...
  """

  def __init__(self, params, device_assigner, training=True,
               tree_variables_class=TreeTrainingVariables):
    self.variables = []
    # Set up some scalar variables to run through the device assigner, then
    # we can use those to colocate everything related to a tree.
    self.device_dummies = []
    with ops.device(device_assigner):
      for i in range(params.num_trees):
        self.device_dummies.append(variable_scope.get_variable(
            name='device_dummy_%d' % i, shape=0))

    for i in range(params.num_trees):
      with ops.device(self.device_dummies[i].device):
        self.variables.append(tree_variables_class(params, i, training))

  def __setitem__(self, t, val):
    self.variables[t] = val

  def __getitem__(self, t):
    return self.variables[t]


class RandomForestGraphs(object):
  """Builds TF graphs for random forest training and inference."""

  def __init__(self,
               params,
               device_assigner=None,
               variables=None,
               tree_variables_class=TreeTrainingVariables,
               tree_graphs=None,
               training=True):
    """Creates (or reuses) the forest variables and one graph object per tree.

    Args:
      params: A filled hyperparameter object (see ForestHParams.fill).
      device_assigner: Device function used to place the per-tree variables.
        Defaults to framework_variables.VariableDeviceChooser().
      variables: Optional existing per-tree variables, indexable by tree
        number. If not given, a ForestTrainingVariables is created.
      tree_variables_class: Per-tree variables class passed to
        ForestTrainingVariables when `variables` is not given.
      tree_graphs: Class used to build each tree's graphs. Defaults to
        RandomTreeGraphs.
      training: Passed to ForestTrainingVariables; controls whether stats
        variables are created.
    """
    self.params = params
    self.device_assigner = (
        device_assigner or framework_variables.VariableDeviceChooser())
    logging.info('Constructing forest with params = ')
    logging.info(self.params.__dict__)
    self.variables = variables or ForestTrainingVariables(
        self.params, device_assigner=self.device_assigner, training=training,
        tree_variables_class=tree_variables_class)
    tree_graph_class = tree_graphs or RandomTreeGraphs
    self.trees = [
        tree_graph_class(self.variables[i], self.params, i)
        for i in range(self.params.num_trees)
    ]

  def _bag_features(self, tree_num, input_data):
    """Keeps only the feature columns bagged for tree `tree_num`.

    Splits `input_data` along axis 1 into params.num_features pieces and
    concatenates, along axis 1, the pieces listed in
    params.bagged_features[tree_num].
    """
    split_data = array_ops.split(
        value=input_data, num_or_size_splits=self.params.num_features, axis=1)
    return array_ops.concat(
        [split_data[ind] for ind in self.params.bagged_features[tree_num]], 1)

  def get_all_resource_handles(self):
    """Returns every tree's tree handle followed by every tree's stats handle.

    Stats entries are None for trees built with training=False.
    """
    return ([self.variables[i].tree for i in range(len(self.trees))] +
            [self.variables[i].stats for i in range(len(self.trees))])

  def training_graph(self,
                     input_data,
                     input_labels,
                     num_trainers=1,
                     trainer_id=0,
                     **tree_kwargs):
    """Constructs a TF graph for training a random forest.

    Args:
      input_data: A tensor or dict of string->Tensor for input data.
      input_labels: A tensor or placeholder for labels associated with
        input_data.
      num_trainers: Number of parallel trainers to split trees among.
      trainer_id: Which trainer this instance is.
      **tree_kwargs: Keyword arguments passed to each tree's training_graph.

    Returns:
      The last op in the random forest training graph.

    Raises:
      NotImplementedError: If trying to use bagging with sparse features.
    """
    processed_dense_features, processed_sparse_features, data_spec = (
        data_ops.ParseDataTensorOrDict(input_data))

    # Initialized so that input_labels=None does not leave `labels` unbound.
    labels = None
    if input_labels is not None:
      labels = data_ops.ParseLabelTensorOrDict(input_labels)

    # NOTE(review): get_default_data_spec is not defined in this class;
    # presumably a subclass provides it. Verify before relying on it.
    data_spec = data_spec or self.get_default_data_spec(input_data)

    tree_graphs = []
    # Trees are split contiguously among trainers; this trainer builds
    # trees [tree_start, tree_end).
    trees_per_trainer = self.params.num_trees / num_trainers
    tree_start = int(trainer_id * trees_per_trainer)
    tree_end = int((trainer_id + 1) * trees_per_trainer)
    for i in range(tree_start, tree_end):
      with ops.device(self.variables.device_dummies[i].device):
        seed = self.params.base_random_seed
        if seed != 0:
          seed += i
        # If using bagging, randomly select some of the input.
        tree_data = processed_dense_features
        tree_labels = labels
        if self.params.bagging_fraction < 1.0:
          # TODO(gilberth): Support bagging for sparse features.
          if processed_sparse_features is not None:
            raise NotImplementedError(
                'Bagging not supported with sparse features.')
          # TODO(thomaswc): This does sampling without replacement.  Consider
          # also allowing sampling with replacement as an option.
          batch_size = array_ops.strided_slice(
              array_ops.shape(processed_dense_features), [0], [1])
          r = random_ops.random_uniform(batch_size, seed=seed)
          mask = math_ops.less(
              r, array_ops.ones_like(r) * self.params.bagging_fraction)
          gather_indices = array_ops.squeeze(
              array_ops.where(mask), squeeze_dims=[1])
          # TODO(thomaswc): Calculate out-of-bag data and labels, and store
          # them for use in calculating statistics later.
          tree_data = array_ops.gather(processed_dense_features, gather_indices)
          tree_labels = array_ops.gather(labels, gather_indices)
        if self.params.bagged_features:
          if processed_sparse_features is not None:
            raise NotImplementedError(
                'Feature bagging not supported with sparse features.')
          tree_data = self._bag_features(i, tree_data)

        tree_graphs.append(self.trees[i].training_graph(
            tree_data,
            tree_labels,
            seed,
            data_spec=data_spec,
            sparse_features=processed_sparse_features,
            **tree_kwargs))

    return control_flow_ops.group(*tree_graphs, name='train')

  def inference_graph(self, input_data, **inference_args):
    """Constructs a TF graph for evaluating a random forest.

    Args:
      input_data: A tensor or dict of string->Tensor for the input data.
                  This input_data must generate the same spec as the
                  input_data used in training_graph:  the dict must have
                  the same keys, for example, and all tensors must have
                  the same size in their first dimension.
      **inference_args: Keyword arguments to pass through to each tree.

    Returns:
      A tuple of (probabilities, tree_paths, variance): the per-tree
      predictions averaged over trees, the per-tree paths stacked on axis 1,
      and max(0, E[p^2] - E[p]^2) computed across trees.

    Raises:
      NotImplementedError: If trying to use feature bagging with sparse
        features.
    """
    processed_dense_features, processed_sparse_features, data_spec = (
        data_ops.ParseDataTensorOrDict(input_data))

    probabilities = []
    paths = []
    for i in range(self.params.num_trees):
      with ops.device(self.variables.device_dummies[i].device):
        tree_data = processed_dense_features
        if self.params.bagged_features:
          if processed_sparse_features is not None:
            raise NotImplementedError(
                'Feature bagging not supported with sparse features.')
          tree_data = self._bag_features(i, tree_data)
        probs, path = self.trees[i].inference_graph(
            tree_data,
            data_spec,
            sparse_features=processed_sparse_features,
            **inference_args)
        probabilities.append(probs)
        paths.append(path)
    with ops.device(self.variables.device_dummies[0].device):
      # shape of all_predict should be [batch_size, num_trees, num_outputs]
      all_predict = array_ops.stack(probabilities, axis=1)
      average_values = math_ops.div(
          math_ops.reduce_sum(all_predict, 1),
          self.params.num_trees,
          name='probabilities')
      tree_paths = array_ops.stack(paths, axis=1)

      # Variance across trees, clamped at 0 to absorb rounding error.
      expected_squares = math_ops.div(
          math_ops.reduce_sum(all_predict * all_predict, 1),
          self.params.num_trees)
      regression_variance = math_ops.maximum(
          0., expected_squares - average_values * average_values)
      return average_values, tree_paths, regression_variance

  def average_size(self):
    """Constructs a TF graph for evaluating the average size of a forest.

    Returns:
      The average number of nodes over the trees.
    """
    sizes = []
    for i in range(self.params.num_trees):
      with ops.device(self.variables.device_dummies[i].device):
        sizes.append(self.trees[i].size())
    return math_ops.reduce_mean(math_ops.to_float(array_ops.stack(sizes)))

  # pylint: disable=unused-argument
  def training_loss(self, features, labels, name='training_loss'):
    """Returns the negated average forest size (features/labels unused)."""
    return math_ops.negative(self.average_size(), name=name)

  # pylint: disable=unused-argument
  def validation_loss(self, features, labels):
    """Returns the negated average forest size (features/labels unused)."""
    return math_ops.negative(self.average_size())

  def average_impurity(self):
    """Constructs a TF graph for evaluating the leaf impurity of a forest.

    NOTE(review): RandomTreeGraphs does not define average_impurity, so this
    presumably only works with a custom tree_graphs class. Verify.

    Returns:
      The last op in the graph.
    """
    impurities = []
    for i in range(self.params.num_trees):
      with ops.device(self.variables.device_dummies[i].device):
        impurities.append(self.trees[i].average_impurity())
    return math_ops.reduce_mean(array_ops.stack(impurities))

  def feature_importances(self):
    """Returns per-feature usage counts summed over trees, normalized to sum 1.

    NOTE(review): if no tree uses any feature, this divides by zero.
    """
    tree_counts = [self.trees[i].feature_usage_counts()
                   for i in range(self.params.num_trees)]
    total_counts = math_ops.reduce_sum(array_ops.stack(tree_counts, 0), 0)
    return total_counts / math_ops.reduce_sum(total_counts)


class RandomTreeGraphs(object):
  """Builds TF graphs for random tree training and inference."""

  def __init__(self, variables, params, tree_num):
    """Stores the tree's variables, the forest params and the tree index."""
    self.variables = variables
    self.params = params
    self.tree_num = tree_num

  def training_graph(self,
                     input_data,
                     input_labels,
                     random_seed,
                     data_spec,
                     sparse_features=None,
                     input_weights=None):
    """Constructs a TF graph for training a random tree.

    Args:
      input_data: A tensor or placeholder for input data.
      input_labels: A tensor or placeholder for labels associated with
        input_data.
      random_seed: The random number generator seed to use for this tree.  0
        means use the current time as the seed.
      data_spec: A data_ops.TensorForestDataSpec object specifying the
        original feature/columns of the data.
      sparse_features: A tf.SparseTensor for sparse input data.
      input_weights: A float tensor or placeholder holding per-input weights,
        or None if all inputs are to be weighted equally.

    Returns:
      The last op in the random tree training graph.
    """
    # TODO(gilberth): Use this.
    unused_epoch = math_ops.to_int32(get_epoch_variable())

    # The ops below receive empty lists for absent optional inputs.
    if input_weights is None:
      input_weights = []

    sparse_indices = []
    sparse_values = []
    sparse_shape = []
    if sparse_features is not None:
      sparse_indices = sparse_features.indices
      sparse_values = sparse_features.values
      sparse_shape = sparse_features.dense_shape

    if input_data is None:
      input_data = []

    leaf_ids = model_ops.traverse_tree_v4(
        self.variables.tree,
        input_data,
        sparse_indices,
        sparse_values,
        sparse_shape,
        input_spec=data_spec.SerializeToString(),
        params=self.params.serialized_params_proto)

    update_model = model_ops.update_model_v4(
        self.variables.tree,
        leaf_ids,
        input_labels,
        input_weights,
        params=self.params.serialized_params_proto)

    finished_nodes = stats_ops.process_input_v4(
        self.variables.tree,
        self.variables.stats,
        input_data,
        sparse_indices,
        sparse_values,
        sparse_shape,
        input_labels,
        input_weights,
        leaf_ids,
        input_spec=data_spec.SerializeToString(),
        random_seed=random_seed,
        params=self.params.serialized_params_proto)

    # Only grow the tree after the leaf model update has run.
    with ops.control_dependencies([update_model]):
      return stats_ops.grow_tree_v4(
          self.variables.tree,
          self.variables.stats,
          finished_nodes,
          params=self.params.serialized_params_proto)

  def inference_graph(self, input_data, data_spec, sparse_features=None):
    """Constructs a TF graph for evaluating a random tree.

    Args:
      input_data: A tensor or placeholder for input data.
      data_spec: A TensorForestDataSpec proto specifying the original
        input columns.
      sparse_features: A tf.SparseTensor for sparse input data.

    Returns:
      A tuple of (probabilities, tree_paths).
    """
    sparse_indices = []
    sparse_values = []
    sparse_shape = []
    if sparse_features is not None:
      sparse_indices = sparse_features.indices
      sparse_values = sparse_features.values
      sparse_shape = sparse_features.dense_shape
    if input_data is None:
      input_data = []

    return model_ops.tree_predictions_v4(
        self.variables.tree,
        input_data,
        sparse_indices,
        sparse_values,
        sparse_shape,
        input_spec=data_spec.SerializeToString(),
        params=self.params.serialized_params_proto)

  def size(self):
    """Constructs a TF graph for evaluating the current number of nodes.

    Returns:
      The current number of nodes in the tree.
    """
    return model_ops.tree_size(self.variables.tree)

  def feature_usage_counts(self):
    """Returns the op computing this tree's per-feature usage counts."""
    return model_ops.feature_usage_counts(
        self.variables.tree, params=self.params.serialized_params_proto)