1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Implementation of handler for split nodes for float columns. 16 17 The general idea in batch split finding is that each handler will accumulate its 18 own statistics on multiple workers. After some steps, the master runs 19 make_splits() sub-graph of each handler and each handler returns its best split 20 per partition. 21 22 The way we ensure consistency of statistics is by using stamp_tokens for read 23 and write operations. During each update of the model, a new stamp token is 24 created. This stamp token makes sure that updates from the previous iterations 25 are not included in the statistics for this iteration. 26 27 Inequality splits for float features are created similar to the method described 28 in Approximate Algorithm described in https://arxiv.org/pdf/1603.02754v3.pdf. 29 Weighted quantiles of the feature columns are computed in a distributed fashion 30 using quantile_ops.quantile_accumulator. 31 After certain number of steps of parallel accumulation of quantile statistics, 32 we decide on bucket boundaries. These bucket boundaries are then used for the 33 next N steps to accumulate gradients and hessians per bucket. 34 35 In this implementation, we gather quantile statistics and gradient statistics 36 concurrently. That means that we don't wait until we have enough quantile 37 statistics for bucketization before we start gathering gradient stats. Instead 38 during each step we create quantile stats for the next iteration and use the 39 previous quantile buckets for gradient stats accumulation. 40 In make_splits, we do these steps: 41 1) Get the buckets that were used creating for the gradient stats. 42 2) Create bucket boundaries for the next N iterations and clear the accumulated 43 quantile stats. 44 n3) Get the accumulated gradient stats and clear the accumulator. This step can 45 run in parallel to step 2. 46 4) For each leaf node in the current tree (partition): 47 4.1) Get the overall gain computed with gradients and hessians of all 48 examples that end up in this partition. 49 4.2) Compute tensors of left and right cumulative sum of gradients, hessians 50 and gain. The first dimension of these tensors are the bucket 51 boundaries. 52 4.3) Find the gains for all bucket boundaries: 53 split_gains = left_gain + right_gain - overall_gain. 54 4.4) Find the bucket boundary that has the best gain (argmax(split_gains)) 55 4.5) For Sparse handler, we also consider the gain for when the examples go 56 the left child and when the examples go to the right child and pick the 57 default direction that yields the most gain. 58 """ 59 60 from __future__ import absolute_import 61 from __future__ import division 62 from __future__ import print_function 63 64 import re 65 66 from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler 67 from tensorflow.contrib.boosted_trees.python.ops import quantile_ops 68 from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops 69 from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops 70 from tensorflow.python.framework import constant_op 71 from tensorflow.python.framework import dtypes 72 from tensorflow.python.framework import function 73 from tensorflow.python.framework import ops 74 from tensorflow.python.framework import sparse_tensor 75 from tensorflow.python.ops import array_ops 76 from tensorflow.python.ops import control_flow_ops 77 from tensorflow.python.ops import math_ops 78 _BIAS_FEATURE_ID = -1 79 # Pattern to remove all non alpha numeric from a string. 80 _PATTERN = re.compile(r"[\W_]+") 81 82 83 class InequalitySplitHandler(base_split_handler.BaseSplitHandler): 84 """Base class for handlers of inequality splits.""" 85 86 def __init__(self, 87 l1_regularization, 88 l2_regularization, 89 tree_complexity_regularization, 90 min_node_weight, 91 feature_column_group_id, 92 epsilon, 93 num_quantiles, 94 gradient_shape, 95 hessian_shape, 96 multiclass_strategy, 97 init_stamp_token=0, 98 name=None): 99 """Initialize the internal state for this split handler. 100 101 Args: 102 l1_regularization: L1 regularization applied for this split handler. 103 l2_regularization: L2 regularization applied for this split handler. 104 tree_complexity_regularization: Tree complexity regularization applied 105 for this split handler. 106 min_node_weight: Minimum sum of weights of examples in each partition to 107 be considered for splitting. 108 feature_column_group_id: Feature column group index. 109 epsilon: A float, the error bound for quantile computation. 110 num_quantiles: An int, the number of buckets to create from the histogram. 111 gradient_shape: A TensorShape, containing shape of gradients. 112 hessian_shape: A TensorShape, containing shape of hessians. 113 multiclass_strategy: Strategy describing how to treat multiclass problems. 114 init_stamp_token: A tensor containing an scalar for initial stamp of the 115 stamped objects. 116 name: An optional handler name. 117 """ 118 super(InequalitySplitHandler, self).__init__( 119 name=name, 120 l1_regularization=l1_regularization, 121 l2_regularization=l2_regularization, 122 tree_complexity_regularization=tree_complexity_regularization, 123 min_node_weight=min_node_weight, 124 feature_column_group_id=feature_column_group_id, 125 gradient_shape=gradient_shape, 126 hessian_shape=hessian_shape, 127 multiclass_strategy=multiclass_strategy) 128 self._stats_accumulator = stats_accumulator_ops.StatsAccumulator( 129 init_stamp_token, 130 gradient_shape, 131 hessian_shape, 132 name="StatsAccumulator/{}".format(self._name)) 133 self._quantile_accumulator = quantile_ops.QuantileAccumulator( 134 init_stamp_token, 135 epsilon=epsilon, 136 num_quantiles=num_quantiles, 137 name="QuantileAccumulator/{}".format(self._name)) 138 139 140 class DenseSplitHandler(InequalitySplitHandler): 141 """Computes stats and finds the best inequality splits on dense columns.""" 142 143 def __init__(self, 144 dense_float_column, 145 l1_regularization, 146 l2_regularization, 147 tree_complexity_regularization, 148 min_node_weight, 149 feature_column_group_id, 150 epsilon, 151 num_quantiles, 152 gradient_shape, 153 hessian_shape, 154 multiclass_strategy, 155 init_stamp_token=0, 156 name=None): 157 """Initialize the internal state for this split handler. 158 159 Args: 160 dense_float_column: A `Tensor` column associated with this handler. 161 l1_regularization: L1 regularization applied for this split handler. 162 l2_regularization: L2 regularization applied for this split handler. 163 tree_complexity_regularization: Tree complexity regularization applied 164 for this split handler. 165 min_node_weight: Minimum sum of weights of examples in each partition to 166 be considered for splitting. 167 feature_column_group_id: Feature column group index. 168 epsilon: A float, the error bound for quantile computation. 169 num_quantiles: An int, the number of buckets to create from the histogram. 170 gradient_shape: A TensorShape, containing shape of gradients. 171 hessian_shape: A TensorShape, containing shape of hessians. 172 multiclass_strategy: Strategy describing how to treat multiclass problems. 173 init_stamp_token: A tensor containing an scalar for initial stamp of the 174 stamped objects. 175 name: An optional handler name. 176 """ 177 super(DenseSplitHandler, self).__init__( 178 l1_regularization=l1_regularization, 179 l2_regularization=l2_regularization, 180 tree_complexity_regularization=tree_complexity_regularization, 181 min_node_weight=min_node_weight, 182 feature_column_group_id=feature_column_group_id, 183 epsilon=epsilon, 184 num_quantiles=num_quantiles, 185 init_stamp_token=init_stamp_token, 186 name=name, 187 gradient_shape=gradient_shape, 188 hessian_shape=hessian_shape, 189 multiclass_strategy=multiclass_strategy) 190 self._dense_float_column = dense_float_column 191 # Register dense_make_stats_update function as an Op to the graph. 192 g = ops.get_default_graph() 193 dense_make_stats_update.add_to_graph(g) 194 195 def scheduled_reads(self): 196 return [self._quantile_accumulator.schedule_get_buckets()] 197 198 def update_stats(self, stamp_token, example_partition_ids, gradients, 199 hessians, empty_gradients, empty_hessians, weights, 200 is_active, scheduled_reads): 201 """Updates the state for dense split handler. 202 203 Args: 204 stamp_token: An int32 scalar tensor containing the current stamp token. 205 example_partition_ids: A dense tensor, containing an int32 for each 206 example which is the partition id that the example ends up in. 207 gradients: A dense tensor of gradients. 208 hessians: A dense tensor of hessians. 209 empty_gradients: A dense empty tensor of the same shape (for dimensions > 210 0) as gradients. 211 empty_hessians: A dense empty tensor of the same shape (for dimensions > 212 0) as hessians. 213 weights: A dense float32 tensor with a weight for each example. 214 is_active: A boolean tensor that says if this handler is active or not. 215 One value for the current layer and one value for the next layer. 216 scheduled_reads: List of scheduled reads for this handler. 217 218 Returns: 219 The op that updates the stats for this handler. 220 """ 221 name = _PATTERN.sub("", self._name) 222 with ops.name_scope(name, "DenseSplitHandler"): 223 are_buckets_ready, buckets = scheduled_reads[0] 224 (quantile_values, quantile_weights, example_partition_ids, 225 feature_ids, gradients, hessians) = dense_make_stats_update( 226 is_active, are_buckets_ready, self._dense_float_column, buckets, 227 example_partition_ids, gradients, hessians, weights, empty_gradients, 228 empty_hessians) 229 update_quantiles = self._quantile_accumulator.schedule_add_summary( 230 stamp_token=stamp_token, 231 column=quantile_values, 232 example_weights=quantile_weights) 233 update_stats = self._stats_accumulator.schedule_add( 234 example_partition_ids, feature_ids, gradients, hessians) 235 return control_flow_ops.no_op(), [update_quantiles, update_stats] 236 237 def make_splits(self, stamp_token, next_stamp_token, class_id): 238 """Create the best split using the accumulated stats and flush the state.""" 239 # Get the bucket boundaries 240 are_splits_ready, buckets = ( 241 self._quantile_accumulator.get_buckets(stamp_token)) 242 # After we receive the boundaries from previous iteration we can flush 243 # the quantile accumulator. 244 with ops.control_dependencies([buckets]): 245 flush_quantiles = self._quantile_accumulator.flush( 246 stamp_token=stamp_token, next_stamp_token=next_stamp_token) 247 248 # Get the aggregated gradients and hessians per <partition_id, feature_id> 249 # pair. 250 # In order to distribute the computation on all the PSs we use the PS that 251 # had the stats accumulator on. 252 with ops.device(None): 253 with ops.device(self._stats_accumulator.resource().device): 254 num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( 255 self._stats_accumulator.flush(stamp_token, next_stamp_token)) 256 257 # Put quantile and stats accumulator flushing in the dependency path. 258 are_splits_ready = control_flow_ops.with_dependencies( 259 [flush_quantiles, partition_ids], are_splits_ready) 260 261 partition_ids, gains, split_infos = ( 262 split_handler_ops.build_dense_inequality_splits( 263 num_minibatches=num_minibatches, 264 bucket_boundaries=buckets, 265 partition_ids=partition_ids, 266 bucket_ids=bucket_ids, 267 gradients=gradients, 268 hessians=hessians, 269 class_id=class_id, 270 feature_column_group_id=self._feature_column_group_id, 271 l1_regularization=self._l1_regularization, 272 l2_regularization=self._l2_regularization, 273 tree_complexity_regularization=self. 274 _tree_complexity_regularization, 275 min_node_weight=self._min_node_weight, 276 multiclass_strategy=self._multiclass_strategy)) 277 return (are_splits_ready, partition_ids, gains, split_infos) 278 279 280 class SparseSplitHandler(InequalitySplitHandler): 281 """Computes stats and finds the best inequality splits on sparse columns.""" 282 283 def __init__(self, 284 sparse_float_column, 285 l1_regularization, 286 l2_regularization, 287 tree_complexity_regularization, 288 min_node_weight, 289 feature_column_group_id, 290 epsilon, 291 num_quantiles, 292 gradient_shape, 293 hessian_shape, 294 multiclass_strategy, 295 init_stamp_token=0, 296 name=None): 297 """Initialize the internal state for this split handler. 298 299 Args: 300 sparse_float_column: A `SparseTensor` column associated with this handler. 301 l1_regularization: L1 regularization applied for this split handler. 302 l2_regularization: L2 regularization applied for this split handler. 303 tree_complexity_regularization: Tree complexity regularization applied 304 for this split handler. 305 min_node_weight: Minimum sum of weights of examples in each partition to 306 be considered for splitting. 307 feature_column_group_id: Feature column group index. 308 epsilon: A float, the error bound for quantile computation. 309 num_quantiles: An int, the number of buckets to create from the histogram. 310 gradient_shape: A TensorShape, containing shape of gradients. 311 hessian_shape: A TensorShape, containing shape of hessians. 312 multiclass_strategy: Strategy describing how to treat multiclass problems. 313 init_stamp_token: A tensor containing an scalar for initial stamp of the 314 stamped objects. 315 name: An optional handler name. 316 """ 317 super(SparseSplitHandler, self).__init__( 318 l1_regularization=l1_regularization, 319 l2_regularization=l2_regularization, 320 tree_complexity_regularization=tree_complexity_regularization, 321 min_node_weight=min_node_weight, 322 feature_column_group_id=feature_column_group_id, 323 epsilon=epsilon, 324 num_quantiles=num_quantiles, 325 gradient_shape=gradient_shape, 326 hessian_shape=hessian_shape, 327 multiclass_strategy=multiclass_strategy, 328 init_stamp_token=init_stamp_token, 329 name=name) 330 # Register sparse_make_stats_update function as an Op to the graph. 331 g = ops.get_default_graph() 332 sparse_make_stats_update.add_to_graph(g) 333 self._sparse_float_column = sparse_float_column 334 335 def scheduled_reads(self): 336 return [self._quantile_accumulator.schedule_get_buckets()] 337 338 def update_stats(self, stamp_token, example_partition_ids, gradients, 339 hessians, empty_gradients, empty_hessians, weights, 340 is_active, scheduled_reads): 341 """Updates the state for dense split handler. 342 343 Args: 344 stamp_token: An int32 scalar tensor containing the current stamp token. 345 example_partition_ids: A dense tensor, containing an int32 for each 346 example which is the partition id that the example ends up in. 347 gradients: A dense tensor of gradients. 348 hessians: A dense tensor of hessians. 349 empty_gradients: A dense empty tensor of the same shape (for dimensions > 350 0) as gradients. 351 empty_hessians: A dense empty tensor of the same shape (for dimensions > 352 0) as hessians. 353 weights: A dense float32 tensor with a weight for each example. 354 is_active: A boolean tensor that says if this handler is active or not. 355 One value for the current layer and one value for the next layer. 356 scheduled_reads: List of results from the scheduled reads. 357 358 Returns: 359 The op that updates the stats for this handler. 360 """ 361 are_buckets_ready, buckets = scheduled_reads[0] 362 with ops.name_scope(self._name, "SparseSplitHandler"): 363 (quantile_indices, quantile_values, quantile_shapes, quantile_weights, 364 example_partition_ids, 365 feature_ids, gradients, hessians) = sparse_make_stats_update( 366 is_active, are_buckets_ready, self._sparse_float_column.indices, 367 self._sparse_float_column.values, 368 self._sparse_float_column.dense_shape, buckets, 369 example_partition_ids, gradients, hessians, weights, empty_gradients, 370 empty_hessians) 371 update_quantiles = self._quantile_accumulator.schedule_add_summary( 372 stamp_token=stamp_token, 373 column=sparse_tensor.SparseTensor(quantile_indices, quantile_values, 374 quantile_shapes), 375 example_weights=quantile_weights) 376 update_stats = self._stats_accumulator.schedule_add( 377 example_partition_ids, feature_ids, gradients, hessians) 378 return (control_flow_ops.no_op(), [update_quantiles, update_stats]) 379 380 def make_splits(self, stamp_token, next_stamp_token, class_id): 381 """Create the best split using the accumulated stats and flush the state.""" 382 # Get the bucket boundaries 383 are_splits_ready, buckets = ( 384 self._quantile_accumulator.get_buckets(stamp_token)) 385 386 # After we receive the boundaries from previous iteration we can flush 387 # the quantile accumulator. 388 with ops.control_dependencies([buckets]): 389 flush_quantiles = self._quantile_accumulator.flush( 390 stamp_token=stamp_token, next_stamp_token=next_stamp_token) 391 392 with ops.device(None): 393 with ops.device(self._stats_accumulator.resource().device): 394 num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( 395 self._stats_accumulator.flush(stamp_token, next_stamp_token)) 396 397 # Put quantile and stats accumulator flushing in the dependency path. 398 are_splits_ready = control_flow_ops.with_dependencies( 399 [flush_quantiles, partition_ids], are_splits_ready) 400 partition_ids, gains, split_infos = ( 401 split_handler_ops.build_sparse_inequality_splits( 402 num_minibatches=num_minibatches, 403 bucket_boundaries=buckets, 404 partition_ids=partition_ids, 405 bucket_ids=bucket_ids, 406 gradients=gradients, 407 hessians=hessians, 408 class_id=class_id, 409 feature_column_group_id=self._feature_column_group_id, 410 l1_regularization=self._l1_regularization, 411 l2_regularization=self._l2_regularization, 412 tree_complexity_regularization=self. 413 _tree_complexity_regularization, 414 min_node_weight=self._min_node_weight, 415 bias_feature_id=_BIAS_FEATURE_ID, 416 multiclass_strategy=self._multiclass_strategy)) 417 return (are_splits_ready, partition_ids, gains, split_infos) 418 419 420 @function.Defun(dtypes.bool, dtypes.bool, dtypes.float32, dtypes.float32, 421 dtypes.int32, dtypes.float32, dtypes.float32, dtypes.float32, 422 dtypes.float32, dtypes.float32) 423 def dense_make_stats_update(is_active, are_buckets_ready, float_column, 424 quantile_buckets, example_partition_ids, gradients, 425 hessians, weights, empty_gradients, empty_hessians): 426 """Updates the state for dense split handler.""" 427 empty_float = constant_op.constant([], dtype=dtypes.float32) 428 429 quantile_values, quantile_weights = control_flow_ops.cond( 430 is_active[1], # For the next layer, this handler is inactive. 431 lambda: (float_column, weights), 432 lambda: (empty_float, empty_float)) 433 434 def ready_inputs_fn(): 435 """Branch to execute when quantiles are ready.""" 436 quantized_feature = quantile_ops.quantiles([float_column], [], 437 [quantile_buckets], [], []) 438 quantized_feature = math_ops.cast(quantized_feature[0], dtypes.int64) 439 quantized_feature = array_ops.squeeze(quantized_feature, axis=0) 440 return (example_partition_ids, quantized_feature, gradients, hessians) 441 442 def not_ready_inputs_fn(): 443 return (constant_op.constant([], dtype=dtypes.int32), 444 constant_op.constant([[]], dtype=dtypes.int64, shape=[1, 2]), 445 empty_gradients, empty_hessians) 446 447 example_partition_ids, feature_ids, gradients, hessians = ( 448 control_flow_ops.cond( 449 math_ops.logical_and(are_buckets_ready, is_active[0]), 450 ready_inputs_fn, not_ready_inputs_fn)) 451 return (quantile_values, quantile_weights, example_partition_ids, feature_ids, 452 gradients, hessians) 453 454 455 @function.Defun(dtypes.bool, dtypes.bool, dtypes.int64, dtypes.float32, 456 dtypes.int64, dtypes.float32, dtypes.int32, dtypes.float32, 457 dtypes.float32, dtypes.float32, dtypes.float32, dtypes.float32) 458 def sparse_make_stats_update( 459 is_active, are_buckets_ready, sparse_column_indices, sparse_column_values, 460 sparse_column_shape, quantile_buckets, example_partition_ids, gradients, 461 hessians, weights, empty_gradients, empty_hessians): 462 """Updates the state for this split handler.""" 463 464 def quantiles_ready(): 465 """The subgraph for when the quantiles are ready.""" 466 quantized_feature = quantile_ops.quantiles([], [sparse_column_values], [], 467 [quantile_buckets], 468 [sparse_column_indices]) 469 470 quantized_feature = math_ops.cast(quantized_feature[1], dtypes.int64) 471 quantized_feature = array_ops.squeeze(quantized_feature, axis=0) 472 473 example_indices, _ = array_ops.split( 474 sparse_column_indices, num_or_size_splits=2, axis=1) 475 example_indices = array_ops.squeeze(example_indices, [1]) 476 filtered_gradients = array_ops.gather(gradients, example_indices) 477 filtered_hessians = array_ops.gather(hessians, example_indices) 478 filtered_partition_ids = array_ops.gather(example_partition_ids, 479 example_indices) 480 unique_partitions, mapped_partitions = array_ops.unique( 481 example_partition_ids) 482 483 # Compute aggregate stats for each partition. 484 per_partition_gradients = math_ops.unsorted_segment_sum( 485 gradients, mapped_partitions, array_ops.size(unique_partitions)) 486 per_partition_hessians = math_ops.unsorted_segment_sum( 487 hessians, mapped_partitions, array_ops.size(unique_partitions)) 488 489 # Prepend a bias feature per partition that accumulates the stats for all 490 # examples in that partition. 491 bias_feature_ids = array_ops.fill( 492 array_ops.shape(unique_partitions), _BIAS_FEATURE_ID) 493 bias_feature_ids = math_ops.cast(bias_feature_ids, dtypes.int64) 494 zeros = array_ops.zeros_like(bias_feature_ids) 495 bias_feature_ids = array_ops.stack([bias_feature_ids, zeros], axis=1) 496 497 partition_ids = array_ops.concat( 498 [unique_partitions, filtered_partition_ids], 0) 499 filtered_gradients = array_ops.concat( 500 [per_partition_gradients, filtered_gradients], 0) 501 filtered_hessians = array_ops.concat( 502 [per_partition_hessians, filtered_hessians], 0) 503 504 bucket_ids = array_ops.concat([bias_feature_ids, quantized_feature], 0) 505 506 return partition_ids, bucket_ids, filtered_gradients, filtered_hessians 507 508 def quantiles_not_ready(): 509 """The subgraph for when the quantiles are not ready.""" 510 return (constant_op.constant([], dtype=dtypes.int32), 511 constant_op.constant([], dtype=dtypes.int64, shape=[1, 2]), 512 empty_gradients, empty_hessians) 513 514 empty_float = constant_op.constant([], dtype=dtypes.float32) 515 handler_not_active = (constant_op.constant( 516 [], dtype=dtypes.int64, shape=[0, 2]), empty_float, constant_op.constant( 517 [0, 1], dtype=dtypes.int64), empty_float) 518 handler_active = (sparse_column_indices, sparse_column_values, 519 sparse_column_shape, weights) 520 quantile_indices, quantile_values, quantile_shape, quantile_weights = ( 521 control_flow_ops.cond(is_active[1], lambda: handler_active, 522 lambda: handler_not_active)) 523 524 example_partition_ids, feature_ids, gradients, hessians = ( 525 control_flow_ops.cond(are_buckets_ready, quantiles_ready, 526 quantiles_not_ready)) 527 528 return (quantile_indices, quantile_values, quantile_shape, quantile_weights, 529 example_partition_ids, feature_ids, gradients, hessians) 530