# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Implementation of handler for split nodes for float columns.
     16 
     17 The general idea in batch split finding is that each handler will accumulate its
     18 own statistics on multiple workers. After some steps, the master runs
     19 make_splits() sub-graph of each handler and each handler returns its best split
     20 per partition.
     21 
     22 The way we ensure consistency of statistics is by using stamp_tokens for read
     23 and write operations. During each update of the model, a new stamp token is
     24 created. This stamp token makes sure that updates from the previous iterations
     25 are not included in the statistics for this iteration.
     26 
     27 Inequality splits for float features are created similar to the method described
     28 in Approximate Algorithm described in https://arxiv.org/pdf/1603.02754v3.pdf.
     29 Weighted quantiles of the feature columns are computed in a distributed fashion
     30 using quantile_ops.quantile_accumulator.
     31 After certain number of steps of parallel accumulation of quantile statistics,
     32 we decide on bucket boundaries. These bucket boundaries are then used for the
     33 next N steps to accumulate gradients and hessians per bucket.
     34 
     35 In this implementation, we gather quantile statistics and gradient statistics
     36 concurrently. That means that we don't wait until we have enough quantile
     37 statistics for bucketization before we start gathering gradient stats. Instead
     38 during each step we create quantile stats for the next iteration and use the
     39 previous quantile buckets for gradient stats accumulation.
     40 In make_splits, we do these steps:
     41 1) Get the buckets that were used creating for the gradient stats.
     42 2) Create bucket boundaries for the next N iterations and clear the accumulated
     43    quantile stats.
     44 n3) Get the accumulated gradient stats and clear the accumulator. This step can
     45    run in parallel to step 2.
     46 4) For each leaf node in the current tree (partition):
     47    4.1) Get the overall gain computed with gradients and hessians of all
     48         examples that end up in this partition.
     49    4.2) Compute tensors of left and right cumulative sum of gradients, hessians
     50         and gain. The first dimension of these tensors are the bucket
     51         boundaries.
     52    4.3) Find the gains for all bucket boundaries:
     53         split_gains = left_gain + right_gain - overall_gain.
     54    4.4) Find the bucket boundary that has the best gain (argmax(split_gains))
     55    4.5) For Sparse handler, we also consider the gain for when the examples go
     56         the left child and when the examples go to the right child and pick the
     57         default direction that yields the most gain.
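
For intuition, here is a minimal pure-Python sketch of steps 4.1-4.4 for one
partition. It is illustrative only: the actual computation is performed by
split_handler_ops.build_dense_inequality_splits, which additionally applies
l1 and tree-complexity regularization, the minimum node weight constraint and
the multiclass strategies. The helper names below are hypothetical.

  def _gain(g, h, l2):
    # Gain of a node whose examples have gradient sum g and hessian sum h.
    return g * g / (h + l2)

  def _best_split(bucket_grads, bucket_hess, l2=1.0):
    total_g, total_h = sum(bucket_grads), sum(bucket_hess)
    overall_gain = _gain(total_g, total_h, l2)
    best_gain, best_bucket = None, None
    left_g = left_h = 0.0
    for i in range(len(bucket_grads) - 1):
      left_g += bucket_grads[i]
      left_h += bucket_hess[i]
      right_g, right_h = total_g - left_g, total_h - left_h
      split_gain = (_gain(left_g, left_h, l2) +
                    _gain(right_g, right_h, l2) - overall_gain)
      if best_gain is None or split_gain > best_gain:
        best_gain, best_bucket = split_gain, i
    return best_gain, best_bucket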
     58 """
     59 
     60 from __future__ import absolute_import
     61 from __future__ import division
     62 from __future__ import print_function
     63 
     64 import re
     65 
     66 from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler
     67 from tensorflow.contrib.boosted_trees.python.ops import quantile_ops
     68 from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops
     69 from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops
     70 from tensorflow.python.framework import constant_op
     71 from tensorflow.python.framework import dtypes
     72 from tensorflow.python.framework import function
     73 from tensorflow.python.framework import ops
     74 from tensorflow.python.framework import sparse_tensor
     75 from tensorflow.python.ops import array_ops
     76 from tensorflow.python.ops import control_flow_ops
     77 from tensorflow.python.ops import math_ops

_BIAS_FEATURE_ID = -1
# Pattern used to strip all non-alphanumeric characters from a string.
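# For example (illustrative): _PATTERN.sub("", "Dense_Handler/0") yields
# "DenseHandler0".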
_PATTERN = re.compile(r"[\W_]+")


class InequalitySplitHandler(base_split_handler.BaseSplitHandler):
  """Base class for handlers of inequality splits."""

  def __init__(self,
               l1_regularization,
               l2_regularization,
               tree_complexity_regularization,
               min_node_weight,
               feature_column_group_id,
               epsilon,
               num_quantiles,
               gradient_shape,
               hessian_shape,
               multiclass_strategy,
               init_stamp_token=0,
               name=None):
    """Initialize the internal state for this split handler.

    Args:
      l1_regularization: L1 regularization applied for this split handler.
      l2_regularization: L2 regularization applied for this split handler.
      tree_complexity_regularization: Tree complexity regularization applied
          for this split handler.
      min_node_weight: Minimum sum of weights of examples in each partition to
          be considered for splitting.
      feature_column_group_id: Feature column group index.
      epsilon: A float, the error bound for quantile computation.
      num_quantiles: An int, the number of buckets to create from the histogram.
      gradient_shape: A TensorShape, containing the shape of gradients.
      hessian_shape: A TensorShape, containing the shape of hessians.
      multiclass_strategy: Strategy describing how to treat multiclass problems.
      init_stamp_token: A scalar tensor with the initial stamp token for the
          stamped objects.
      name: An optional handler name.
    """
    super(InequalitySplitHandler, self).__init__(
        name=name,
        l1_regularization=l1_regularization,
        l2_regularization=l2_regularization,
        tree_complexity_regularization=tree_complexity_regularization,
        min_node_weight=min_node_weight,
        feature_column_group_id=feature_column_group_id,
        gradient_shape=gradient_shape,
        hessian_shape=hessian_shape,
        multiclass_strategy=multiclass_strategy)
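    # Two stamped accumulators back each handler: one aggregates
    # (gradient, hessian) stats per (partition, bucket) pair, and one
    # maintains the streaming quantile summaries that produce the bucket
    # boundaries.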
    self._stats_accumulator = stats_accumulator_ops.StatsAccumulator(
        init_stamp_token,
        gradient_shape,
        hessian_shape,
        name="StatsAccumulator/{}".format(self._name))
    self._quantile_accumulator = quantile_ops.QuantileAccumulator(
        init_stamp_token,
        epsilon=epsilon,
        num_quantiles=num_quantiles,
        name="QuantileAccumulator/{}".format(self._name))


class DenseSplitHandler(InequalitySplitHandler):
  """Computes stats and finds the best inequality splits on dense columns."""

  def __init__(self,
               dense_float_column,
               l1_regularization,
               l2_regularization,
               tree_complexity_regularization,
               min_node_weight,
               feature_column_group_id,
               epsilon,
               num_quantiles,
               gradient_shape,
               hessian_shape,
               multiclass_strategy,
               init_stamp_token=0,
               name=None):
    """Initialize the internal state for this split handler.

    Args:
      dense_float_column: A `Tensor` column associated with this handler.
      l1_regularization: L1 regularization applied for this split handler.
      l2_regularization: L2 regularization applied for this split handler.
      tree_complexity_regularization: Tree complexity regularization applied
          for this split handler.
      min_node_weight: Minimum sum of weights of examples in each partition to
          be considered for splitting.
      feature_column_group_id: Feature column group index.
      epsilon: A float, the error bound for quantile computation.
      num_quantiles: An int, the number of buckets to create from the histogram.
      gradient_shape: A TensorShape, containing the shape of gradients.
      hessian_shape: A TensorShape, containing the shape of hessians.
      multiclass_strategy: Strategy describing how to treat multiclass problems.
      init_stamp_token: A scalar tensor with the initial stamp token for the
          stamped objects.
      name: An optional handler name.
    """
    super(DenseSplitHandler, self).__init__(
        l1_regularization=l1_regularization,
        l2_regularization=l2_regularization,
        tree_complexity_regularization=tree_complexity_regularization,
        min_node_weight=min_node_weight,
        feature_column_group_id=feature_column_group_id,
        epsilon=epsilon,
        num_quantiles=num_quantiles,
        init_stamp_token=init_stamp_token,
        name=name,
        gradient_shape=gradient_shape,
        hessian_shape=hessian_shape,
        multiclass_strategy=multiclass_strategy)
    self._dense_float_column = dense_float_column
    # Register the dense_make_stats_update function as an Op in the graph.
    g = ops.get_default_graph()
    dense_make_stats_update.add_to_graph(g)

  def scheduled_reads(self):
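    # Schedule a read of the current bucket boundaries; its result is handed
    # back to update_stats() as scheduled_reads[0].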
    return [self._quantile_accumulator.schedule_get_buckets()]

  def update_stats(self, stamp_token, example_partition_ids, gradients,
                   hessians, empty_gradients, empty_hessians, weights,
                   is_active, scheduled_reads):
    """Updates the state for dense split handler.

    Args:
      stamp_token: An int32 scalar tensor containing the current stamp token.
      example_partition_ids: A dense tensor, containing an int32 for each
        example which is the partition id that the example ends up in.
      gradients: A dense tensor of gradients.
      hessians: A dense tensor of hessians.
      empty_gradients: A dense empty tensor of the same shape (for dimensions >
        0) as gradients.
      empty_hessians: A dense empty tensor of the same shape (for dimensions >
        0) as hessians.
      weights: A dense float32 tensor with a weight for each example.
      is_active: A boolean tensor that says if this handler is active or not.
          One value for the current layer and one value for the next layer.
      scheduled_reads: List of results from the scheduled reads.

    Returns:
      The op that updates the stats for this handler.
    """
    name = _PATTERN.sub("", self._name)
    with ops.name_scope(name, "DenseSplitHandler"):
      are_buckets_ready, buckets = scheduled_reads[0]
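      # are_buckets_ready says whether the quantile accumulator has produced
      # boundaries yet; until then, no gradient stats are collected.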
      (quantile_values, quantile_weights, example_partition_ids,
       feature_ids, gradients, hessians) = dense_make_stats_update(
           is_active, are_buckets_ready, self._dense_float_column, buckets,
           example_partition_ids, gradients, hessians, weights, empty_gradients,
           empty_hessians)
      update_quantiles = self._quantile_accumulator.schedule_add_summary(
          stamp_token=stamp_token,
          column=quantile_values,
          example_weights=quantile_weights)
      update_stats = self._stats_accumulator.schedule_add(
          example_partition_ids, feature_ids, gradients, hessians)
      return control_flow_ops.no_op(), [update_quantiles, update_stats]

  def make_splits(self, stamp_token, next_stamp_token, class_id):
    """Create the best split using the accumulated stats and flush the state."""
    # Get the bucket boundaries.
    are_splits_ready, buckets = (
        self._quantile_accumulator.get_buckets(stamp_token))
    # After we receive the boundaries from the previous iteration, we can flush
    # the quantile accumulator.
    with ops.control_dependencies([buckets]):
      flush_quantiles = self._quantile_accumulator.flush(
          stamp_token=stamp_token, next_stamp_token=next_stamp_token)

    # Get the aggregated gradients and hessians per <partition_id, feature_id>
    # pair.
    # To spread the computation across the parameter servers, each handler's
    # split computation runs on the PS that holds its stats accumulator.
    with ops.device(None):
      with ops.device(self._stats_accumulator.resource().device):
        num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
            self._stats_accumulator.flush(stamp_token, next_stamp_token))

        # Put quantile and stats accumulator flushing in the dependency path.
        are_splits_ready = control_flow_ops.with_dependencies(
            [flush_quantiles, partition_ids], are_splits_ready)

        partition_ids, gains, split_infos = (
            split_handler_ops.build_dense_inequality_splits(
                num_minibatches=num_minibatches,
                bucket_boundaries=buckets,
                partition_ids=partition_ids,
                bucket_ids=bucket_ids,
                gradients=gradients,
                hessians=hessians,
                class_id=class_id,
                feature_column_group_id=self._feature_column_group_id,
                l1_regularization=self._l1_regularization,
                l2_regularization=self._l2_regularization,
                tree_complexity_regularization=(
                    self._tree_complexity_regularization),
                min_node_weight=self._min_node_weight,
                multiclass_strategy=self._multiclass_strategy))
    return (are_splits_ready, partition_ids, gains, split_infos)


class SparseSplitHandler(InequalitySplitHandler):
  """Computes stats and finds the best inequality splits on sparse columns."""

  def __init__(self,
               sparse_float_column,
               l1_regularization,
               l2_regularization,
               tree_complexity_regularization,
               min_node_weight,
               feature_column_group_id,
               epsilon,
               num_quantiles,
               gradient_shape,
               hessian_shape,
               multiclass_strategy,
               init_stamp_token=0,
               name=None):
    """Initialize the internal state for this split handler.

    Args:
      sparse_float_column: A `SparseTensor` column associated with this handler.
      l1_regularization: L1 regularization applied for this split handler.
      l2_regularization: L2 regularization applied for this split handler.
      tree_complexity_regularization: Tree complexity regularization applied
          for this split handler.
      min_node_weight: Minimum sum of weights of examples in each partition to
          be considered for splitting.
      feature_column_group_id: Feature column group index.
      epsilon: A float, the error bound for quantile computation.
      num_quantiles: An int, the number of buckets to create from the histogram.
      gradient_shape: A TensorShape, containing the shape of gradients.
      hessian_shape: A TensorShape, containing the shape of hessians.
      multiclass_strategy: Strategy describing how to treat multiclass problems.
      init_stamp_token: A scalar tensor with the initial stamp token for the
          stamped objects.
      name: An optional handler name.
    """
    super(SparseSplitHandler, self).__init__(
        l1_regularization=l1_regularization,
        l2_regularization=l2_regularization,
        tree_complexity_regularization=tree_complexity_regularization,
        min_node_weight=min_node_weight,
        feature_column_group_id=feature_column_group_id,
        epsilon=epsilon,
        num_quantiles=num_quantiles,
        gradient_shape=gradient_shape,
        hessian_shape=hessian_shape,
        multiclass_strategy=multiclass_strategy,
        init_stamp_token=init_stamp_token,
        name=name)
    # Register the sparse_make_stats_update function as an Op in the graph.
    g = ops.get_default_graph()
    sparse_make_stats_update.add_to_graph(g)
    self._sparse_float_column = sparse_float_column

  def scheduled_reads(self):
    return [self._quantile_accumulator.schedule_get_buckets()]

  def update_stats(self, stamp_token, example_partition_ids, gradients,
                   hessians, empty_gradients, empty_hessians, weights,
                   is_active, scheduled_reads):
    """Updates the state for sparse split handler.

    Args:
      stamp_token: An int32 scalar tensor containing the current stamp token.
      example_partition_ids: A dense tensor, containing an int32 for each
        example which is the partition id that the example ends up in.
      gradients: A dense tensor of gradients.
      hessians: A dense tensor of hessians.
      empty_gradients: A dense empty tensor of the same shape (for dimensions >
        0) as gradients.
      empty_hessians: A dense empty tensor of the same shape (for dimensions >
        0) as hessians.
      weights: A dense float32 tensor with a weight for each example.
      is_active: A boolean tensor that says if this handler is active or not.
          One value for the current layer and one value for the next layer.
      scheduled_reads: List of results from the scheduled reads.

    Returns:
      The op that updates the stats for this handler.
    """
    are_buckets_ready, buckets = scheduled_reads[0]
    with ops.name_scope(self._name, "SparseSplitHandler"):
      (quantile_indices, quantile_values, quantile_shapes, quantile_weights,
       example_partition_ids,
       feature_ids, gradients, hessians) = sparse_make_stats_update(
           is_active, are_buckets_ready, self._sparse_float_column.indices,
           self._sparse_float_column.values,
           self._sparse_float_column.dense_shape, buckets,
           example_partition_ids, gradients, hessians, weights, empty_gradients,
           empty_hessians)
      update_quantiles = self._quantile_accumulator.schedule_add_summary(
          stamp_token=stamp_token,
          column=sparse_tensor.SparseTensor(quantile_indices, quantile_values,
                                            quantile_shapes),
          example_weights=quantile_weights)
      update_stats = self._stats_accumulator.schedule_add(
          example_partition_ids, feature_ids, gradients, hessians)
      return (control_flow_ops.no_op(), [update_quantiles, update_stats])

  def make_splits(self, stamp_token, next_stamp_token, class_id):
    """Create the best split using the accumulated stats and flush the state."""
    # Get the bucket boundaries.
    are_splits_ready, buckets = (
        self._quantile_accumulator.get_buckets(stamp_token))

    # After we receive the boundaries from the previous iteration, we can flush
    # the quantile accumulator.
    with ops.control_dependencies([buckets]):
      flush_quantiles = self._quantile_accumulator.flush(
          stamp_token=stamp_token, next_stamp_token=next_stamp_token)

    with ops.device(None):
      with ops.device(self._stats_accumulator.resource().device):
        num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
            self._stats_accumulator.flush(stamp_token, next_stamp_token))

        # Put quantile and stats accumulator flushing in the dependency path.
        are_splits_ready = control_flow_ops.with_dependencies(
            [flush_quantiles, partition_ids], are_splits_ready)
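        # Unlike the dense case, the sparse builder also receives the bias
        # feature id, which lets it compare sending the missing-value
        # examples left versus right and pick the better default direction
        # (step 4.5 in the module docstring).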
        partition_ids, gains, split_infos = (
            split_handler_ops.build_sparse_inequality_splits(
                num_minibatches=num_minibatches,
                bucket_boundaries=buckets,
                partition_ids=partition_ids,
                bucket_ids=bucket_ids,
                gradients=gradients,
                hessians=hessians,
                class_id=class_id,
                feature_column_group_id=self._feature_column_group_id,
                l1_regularization=self._l1_regularization,
                l2_regularization=self._l2_regularization,
                tree_complexity_regularization=(
                    self._tree_complexity_regularization),
                min_node_weight=self._min_node_weight,
                bias_feature_id=_BIAS_FEATURE_ID,
                multiclass_strategy=self._multiclass_strategy))
    return (are_splits_ready, partition_ids, gains, split_infos)


@function.Defun(dtypes.bool, dtypes.bool, dtypes.float32, dtypes.float32,
                dtypes.int32, dtypes.float32, dtypes.float32, dtypes.float32,
                dtypes.float32, dtypes.float32)
def dense_make_stats_update(is_active, are_buckets_ready, float_column,
                            quantile_buckets, example_partition_ids, gradients,
                            hessians, weights, empty_gradients, empty_hessians):
  """Updates the state for dense split handler."""
  empty_float = constant_op.constant([], dtype=dtypes.float32)

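  # Quantile stats are gathered for the *next* layer (is_active[1]), while
  # the gradient stats below use buckets computed in earlier steps; the two
  # accumulations proceed concurrently (see the module docstring).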
  quantile_values, quantile_weights = control_flow_ops.cond(
      is_active[1],  # Collect only when active for the next layer.
      lambda: (float_column, weights),
      lambda: (empty_float, empty_float))

  def ready_inputs_fn():
    """Branch to execute when quantiles are ready."""
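    # quantiles() maps each dense feature value to the index of the quantile
    # bucket that contains it.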
    quantized_feature = quantile_ops.quantiles([float_column], [],
                                               [quantile_buckets], [], [])
    quantized_feature = math_ops.cast(quantized_feature[0], dtypes.int64)
    quantized_feature = array_ops.squeeze(quantized_feature, axis=0)
    return (example_partition_ids, quantized_feature, gradients, hessians)

  def not_ready_inputs_fn():
    return (constant_op.constant([], dtype=dtypes.int32),
            constant_op.constant([[]], dtype=dtypes.int64, shape=[1, 2]),
            empty_gradients, empty_hessians)

  example_partition_ids, feature_ids, gradients, hessians = (
      control_flow_ops.cond(
          math_ops.logical_and(are_buckets_ready, is_active[0]),
          ready_inputs_fn, not_ready_inputs_fn))
  return (quantile_values, quantile_weights, example_partition_ids, feature_ids,
          gradients, hessians)


@function.Defun(dtypes.bool, dtypes.bool, dtypes.int64, dtypes.float32,
                dtypes.int64, dtypes.float32, dtypes.int32, dtypes.float32,
                dtypes.float32, dtypes.float32, dtypes.float32, dtypes.float32)
def sparse_make_stats_update(
    is_active, are_buckets_ready, sparse_column_indices, sparse_column_values,
    sparse_column_shape, quantile_buckets, example_partition_ids, gradients,
    hessians, weights, empty_gradients, empty_hessians):
  """Updates the state for this split handler."""

  def quantiles_ready():
    """The subgraph for when the quantiles are ready."""
    quantized_feature = quantile_ops.quantiles([], [sparse_column_values], [],
                                               [quantile_buckets],
                                               [sparse_column_indices])

    quantized_feature = math_ops.cast(quantized_feature[1], dtypes.int64)
    quantized_feature = array_ops.squeeze(quantized_feature, axis=0)

    example_indices, _ = array_ops.split(
        sparse_column_indices, num_or_size_splits=2, axis=1)
    example_indices = array_ops.squeeze(example_indices, [1])
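    # example_indices are the row ids of the examples that have a value for
    # this feature; gather their gradients, hessians and partition ids.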
    filtered_gradients = array_ops.gather(gradients, example_indices)
    filtered_hessians = array_ops.gather(hessians, example_indices)
    filtered_partition_ids = array_ops.gather(example_partition_ids,
                                              example_indices)
    unique_partitions, mapped_partitions = array_ops.unique(
        example_partition_ids)

    # Compute aggregate stats for each partition.
    per_partition_gradients = math_ops.unsorted_segment_sum(
        gradients, mapped_partitions, array_ops.size(unique_partitions))
    per_partition_hessians = math_ops.unsorted_segment_sum(
        hessians, mapped_partitions, array_ops.size(unique_partitions))

    # Prepend a bias feature per partition that accumulates the stats for all
    # examples in that partition.
    bias_feature_ids = array_ops.fill(
        array_ops.shape(unique_partitions), _BIAS_FEATURE_ID)
    bias_feature_ids = math_ops.cast(bias_feature_ids, dtypes.int64)
    zeros = array_ops.zeros_like(bias_feature_ids)
    bias_feature_ids = array_ops.stack([bias_feature_ids, zeros], axis=1)

    partition_ids = array_ops.concat(
        [unique_partitions, filtered_partition_ids], 0)
    filtered_gradients = array_ops.concat(
        [per_partition_gradients, filtered_gradients], 0)
    filtered_hessians = array_ops.concat(
        [per_partition_hessians, filtered_hessians], 0)

    bucket_ids = array_ops.concat([bias_feature_ids, quantized_feature], 0)

    return partition_ids, bucket_ids, filtered_gradients, filtered_hessians

  def quantiles_not_ready():
    """The subgraph for when the quantiles are not ready."""
    return (constant_op.constant([], dtype=dtypes.int32),
            constant_op.constant([], dtype=dtypes.int64, shape=[1, 2]),
            empty_gradients, empty_hessians)

  empty_float = constant_op.constant([], dtype=dtypes.float32)
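  # When the handler is inactive for the next layer, feed the quantile
  # accumulator an empty SparseTensor (no index rows, dense shape [0, 1]) and
  # empty weights.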
  handler_not_active = (constant_op.constant(
      [], dtype=dtypes.int64, shape=[0, 2]), empty_float, constant_op.constant(
          [0, 1], dtype=dtypes.int64), empty_float)
  handler_active = (sparse_column_indices, sparse_column_values,
                    sparse_column_shape, weights)
  quantile_indices, quantile_values, quantile_shape, quantile_weights = (
      control_flow_ops.cond(is_active[1], lambda: handler_active,
                            lambda: handler_not_active))

  example_partition_ids, feature_ids, gradients, hessians = (
      control_flow_ops.cond(are_buckets_ready, quantiles_ready,
                            quantiles_not_ready))

  return (quantile_indices, quantile_values, quantile_shape, quantile_weights,
          example_partition_ids, feature_ids, gradients, hessians)