# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed inputs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.data.util import structure
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest


class InputWorkers(object):
  """A 1-to-many mapping from input worker devices to compute devices."""

  def __init__(self, device_map, worker_device_pairs=None, logical_device=0):
    """Initialize an `InputWorkers` object.

    Args:
      device_map: A `DeviceMap` with the computation devices fed by the
        input workers.
      worker_device_pairs: A sequence of pairs:
        `(input device, a tuple of compute devices fed by that input device)`.
      logical_device: The logical device of `device_map` to feed.
    """
    self._device_map = device_map
    self._logical_device = logical_device
    if worker_device_pairs is None:
      worker_device_pairs = ((
          device_util.canonicalize("/device:CPU:0"),
          device_map.logical_to_actual_devices(logical_device)),)
    self._input_worker_devices = tuple(d for d, _ in worker_device_pairs)
    self._fed_devices = tuple(tuple(device_util.canonicalize(d) for d in f)
                              for _, f in worker_device_pairs)
    flattened = tuple(d for l in self._fed_devices for d in l)
    assert (flattened ==
            device_map.logical_to_actual_devices(logical_device)), (
                "flattened: %s logical device %d: %s" %
                (flattened, logical_device,
                 device_map.logical_to_actual_devices(logical_device)))

  @property
  def device_map(self):
    return self._device_map

  @property
  def logical_device(self):
    return self._logical_device

  @property
  def num_workers(self):
    return len(self._input_worker_devices)

  @property
  def worker_devices(self):
    return self._input_worker_devices

  def compute_devices_for_worker(self, worker_index):
    return self._fed_devices[worker_index]

  def __repr__(self):
    devices = self.worker_devices
    debug_repr = ",\n".join("  %d %s: %s" %
                            (i, devices[i], self._fed_devices[i])
                            for i in range(len(devices)))
    return "%s:{\n%s\n  device_map: %s}" % (
        self.__class__.__name__, debug_repr, self._device_map)


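# Illustrative usage sketch (not part of the original module): one CPU input
# device feeding two GPUs on a single worker. The device map class name
# (`values.ReplicaDeviceMap`) is an assumption here; substitute whatever
# `DeviceMap` implementation the calling strategy uses.
#
#   device_map = values.ReplicaDeviceMap(("/device:GPU:0", "/device:GPU:1"))
#   worker_device_pairs = (
#       ("/device:CPU:0", ("/device:GPU:0", "/device:GPU:1")),)
#   input_workers = InputWorkers(device_map, worker_device_pairs)
#   assert input_workers.num_workers == 1
#   assert input_workers.compute_devices_for_worker(0) == (
#       device_util.canonicalize("/device:GPU:0"),
#       device_util.canonicalize("/device:GPU:1"))

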
class InputIterator(object):
  """An input iterator, intended to be passed to `DistributionStrategy.run`."""

  def get_next(self):
    """Returns the next inputs for all replicas."""
    raise NotImplementedError("must be implemented in descendants")

  def initialize(self):
    """Initialize the underlying input dataset, when applicable.

    In eager mode, this will create a new iterator and return it.
    In graph mode, this will initialize the same underlying iterator(s).

    Users are required to call this if either:
    - this iterator was returned from a call to `make_input_fn_iterator` with
      an input function that returns a dataset, or
    - this iterator was returned from a call to `make_dataset_iterator`.

    Returns:
      A list of initialization ops to be executed.
    """
    raise NotImplementedError("must be implemented in descendants")


class InputIteratorImpl(InputIterator):
  """Common implementation for all input iterators."""

  def __init__(self, input_workers, iterators):
    assert isinstance(input_workers, InputWorkers)
    if not input_workers.worker_devices:
      raise ValueError("Should have at least one worker for input iterator.")

    self._iterators = iterators
    self._input_workers = input_workers

  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas."""
    replicas = []
    worker_has_values = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      if name is not None:
        d = tf_device.DeviceSpec.from_string(worker)
        new_name = "%s_%s_%d" % (name, d.job, d.task)
      else:
        new_name = None
      with ops.device(worker):
        worker_has_value, next_element = (
            self._iterators[i].get_next_as_list(new_name))
        worker_has_values.append(worker_has_value)
        # Make `replicas` a flat list of values across all replicas.
        replicas.append(next_element)

    out_of_range_replicas = []

    def out_of_range_fn(worker_index, device):
      """This function will throw an OutOfRange error."""
      # This is only called when there is no data left, so calling get_next()
      # will trigger an OutOfRange error.
      data = self._iterators[worker_index].get_next(device)
      out_of_range_replicas.append(data)
      return data

    # `global_has_value` indicates whether there is data in this global batch.
    # We do an all-reduce across all the workers in the multi-worker case.
    # TODO(b/126259107): Do strategy.reduce for CollectiveAllReduceStrategy.
    if len(worker_has_values) > 1:
      with ops.device(self._input_workers.compute_devices_for_worker(0)[0]):
        # Place the tf.reduce_any op on device 0 to minimize communication
        # cost.
        # TODO(b/128545270): Investigate why placing it on worker 0 causes
        # the entire data to be copied back from device to host.
        global_has_value = math_ops.reduce_any(worker_has_values)
    else:
      global_has_value = worker_has_values[0]

    results = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      with ops.device(worker):
        devices = self._input_workers.compute_devices_for_worker(i)
        for j, device in enumerate(devices):
          with ops.device(device):
            # pylint: disable=undefined-loop-variable
            # pylint: disable=cell-var-from-loop
            # It is fine for the lambda to capture variables from the loop as
            # the lambda is executed in the loop as well.
            result = control_flow_ops.cond(global_has_value,
                                           lambda: replicas[i][j],
                                           lambda: out_of_range_fn(i, device))
            # pylint: enable=cell-var-from-loop
            # pylint: enable=undefined-loop-variable
            results.append(result)
    replicas = results

    # Some dimensions in `replicas` will become unknown after we conditionally
    # return the real tensors or the dummy tensors. We fix the input shapes by
    # using the shapes from `out_of_range_replicas`, because those come from
    # calling get_next() on the iterator.
    flattened_replicas = nest.flatten(replicas)
    for i, replica_data in enumerate(nest.flatten(out_of_range_replicas)):
      flattened_replicas[i].set_shape(replica_data.get_shape())
    replicas = nest.pack_sequence_as(replicas, flattened_replicas)

    return values.regroup(self._input_workers.device_map, replicas)

  def initialize(self):
    """Initialize the underlying iterators.

    Returns:
      A list of any initializer ops that should be run.
    """
    init_ops = []
    for it in self._iterators:
      init_ops.extend(it.initialize())
    return init_ops

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_classes(self):
    return self._iterators[0].output_classes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_shapes(self):
    return self._iterators[0].output_shapes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_types(self):
    return self._iterators[0].output_types

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  def get_iterator(self, worker):
    for i, w in enumerate(self._input_workers.worker_devices):
      if worker == w:
        return self._iterators[i]
    return None


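# Illustrative usage sketch (not part of the original module): the typical
# graph-mode consumption pattern for an `InputIteratorImpl` subclass. The
# session setup is assumed for the example; in eager mode `initialize`
# recreates the iterators instead of returning ops.
#
#   iterator = ...  # e.g. a DatasetIterator or InputFunctionIterator
#   init_ops = iterator.initialize()         # initializer ops in graph mode
#   per_replica_batch = iterator.get_next()  # a PerReplica-structured value
#   with tf.compat.v1.Session() as sess:
#     sess.run(init_ops)
#     sess.run(per_replica_batch)

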
class InputFunctionIterator(InputIteratorImpl):
  """Iterator created from an input function."""

  def __init__(self, input_fn, input_workers, input_contexts):
    """Make an iterator for input provided via an input function.

    Currently implements PER_WORKER mode, in which the `input_fn` is called
    once on each worker.

    TODO(priyag): Add other replication modes.

    Args:
      input_fn: Input function that returns a `tf.data.Dataset` object.
      input_workers: an `InputWorkers` object.
      input_contexts: A list of `InputContext` instances to be passed to call(s)
        to `input_fn`. Length and order should match worker order in
        `worker_device_pairs`.
    """
    assert isinstance(input_workers, InputWorkers)
    if input_workers.num_workers != len(input_contexts):
      raise ValueError(
          "Number of input workers (%d) is not the same as the number of "
          "input_contexts (%d)" %
          (input_workers.num_workers, len(input_contexts)))

    iterators = []
    for i, ctx in enumerate(input_contexts):
      worker = input_workers.worker_devices[i]
      with ops.device(worker):
        result = input_fn(ctx)
        devices = input_workers.compute_devices_for_worker(i)
        if isinstance(result, dataset_ops.DatasetV2):
          iterator = _SingleWorkerDatasetIterator(result, worker, devices)
        elif callable(result):
          iterator = _SingleWorkerCallableIterator(result, worker, devices)
        else:
          raise ValueError(
              "input_fn must return a tf.data.Dataset or a callable.")
        iterators.append(iterator)

    super(InputFunctionIterator, self).__init__(input_workers, iterators)


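# Illustrative usage sketch (not part of the original module): how an
# `input_fn` is typically written. In practice the iterator is constructed by
# `strategy.make_input_fn_iterator(input_fn)`, which supplies the
# `InputContext` objects; the dataset contents here are placeholders.
#
#   def input_fn(input_context):
#     batch_size = input_context.get_per_replica_batch_size(
#         global_batch_size=64)
#     d = tf.data.Dataset.range(1000).batch(batch_size)
#     return d.shard(input_context.num_input_pipelines,
#                    input_context.input_pipeline_id)
#
#   iterator = strategy.make_input_fn_iterator(input_fn)
#   iterator.initialize()

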
class DatasetIterator(InputIteratorImpl):
  """Iterator created from an input dataset."""

  def __init__(self, dataset, input_workers, split_batch_by=None):
    """Make an iterator for the dataset on the given devices.

    If `split_batch_by` is not None, we "split" each batch of the dataset by
    the `split_batch_by` value. To achieve this, we first unbatch the input
    dataset and then rebatch it with the per-replica batch size that is
    calculated using `global_batch_size // split_batch_by`.
    The currently supported datasets are as follows:
    `dataset.batch()` is the last operation on the dataset OR
    `dataset.apply(map_and_batch)` is the last operation on the dataset OR
    `dataset.batch().prefetch()` are the last 2 operations on the dataset OR
    `dataset.apply(map_and_batch).prefetch()` are the last 2 operations.

    TODO(priyag): Support multi worker / host cases properly by cloning
    and sharding the dataset on each worker. The current setup will only work
    in some cases, such as the in-graph multi-worker GPU case. If the input
    pipeline has random shuffling (with a different seed on each worker), each
    worker will see random input from the same overall dataset in each step.
    Otherwise, each worker will see the same input in each step.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      split_batch_by: Optional integer. If present, we "split" each batch of
        the dataset by the `split_batch_by` value.
    """
    assert isinstance(input_workers, InputWorkers)
    if split_batch_by:
      dataset = batching._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

    iterators = []
    for i, worker in enumerate(input_workers.worker_devices):
      with ops.device(worker):
        worker_devices = input_workers.compute_devices_for_worker(i)
        cloned_dataset = dataset
        if not context.executing_eagerly():
          cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
          cloned_dataset = cloned_dataset.with_options(dataset.options())
        iterator = _SingleWorkerDatasetIterator(cloned_dataset, worker,
                                                worker_devices)
        iterators.append(iterator)

    self._element_structure = dataset._element_structure  # pylint: disable=protected-access

    super(DatasetIterator, self).__init__(input_workers, iterators)


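# Illustrative usage sketch (not part of the original module): splitting a
# globally batched dataset across replicas. With a global batch of 64 and
# `split_batch_by=2`, each replica receives batches of 32. The `input_workers`
# object is assumed to be set up as in the earlier sketch.
#
#   dataset = tf.data.Dataset.range(1024).batch(64)  # global batch size 64
#   iterator = DatasetIterator(dataset, input_workers, split_batch_by=2)
#   init_ops = iterator.initialize()
#   per_replica = iterator.get_next()  # each replica sees batches of 32

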
def _dummy_tensor_fn(value_structure):
  """A function to create dummy tensors from `value_structure`."""

  def create_dummy_tensor(feature_shape, feature_type):
    """Create a dummy tensor with possible batch dimensions set to 0."""

    # Ideally we would set the batch dimension to 0, but since
    # DistributionStrategy does not know the batch dimension, we make a best
    # guess. If the feature has unknown dimensions, we set them to 0. If the
    # feature shape is already static, we assume the first dimension is the
    # batch dimension and set it to 0.
    dims = []
    for dim in feature_shape.dims:
      if dim.value is None:
        dims.append(tensor_shape.Dimension(0))
      else:
        dims.append(dim)
    if feature_shape.is_fully_defined() and dims:
      dims[0] = tensor_shape.Dimension(0)

    # Create the dummy tensor.
    dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)
    return dummy_tensor

  result = []
  # pylint: disable=protected-access
  for feature_shape, feature_type in zip(value_structure._flat_shapes,
                                         value_structure._flat_types):
    result.append(create_dummy_tensor(feature_shape, feature_type))

  if isinstance(value_structure, structure.NestedStructure):
    result = nest.pack_sequence_as(value_structure._nested_structure, result)
  else:
    result = result[0]
  # pylint: enable=protected-access

  return result


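# Illustrative sketch (not part of the original module): the shapes produced
# by `_dummy_tensor_fn` for a few element shapes, following the logic in
# `create_dummy_tensor` above.
#
#   element shape          dummy tensor shape
#   (64, 224, 224, 3)  ->  (0, 224, 224, 3)   # static: batch dim zeroed
#   (None, 10)         ->  (0, 10)            # unknown dims become 0
#   ()                 ->  ()                 # scalars stay scalar zeros

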
class _SingleWorkerDatasetIterator(object):
  """Iterator for a single `tf.data.Dataset`."""

  def __init__(self, dataset, worker, devices):
    """Create an iterator for `dataset` to fetch data to the worker's `devices`.

    `MultiDeviceIterator` is used to prefetch input to the devices on the
    given worker.

    Args:
      dataset: A `tf.data.Dataset` instance.
      worker: Worker on which ops should be created.
      devices: Distribute data from `dataset` to these devices.
    """
    self._dataset = dataset
    self._worker = worker
    self._devices = devices
    self._make_iterator()

  def _make_iterator(self):
    """Make an appropriate iterator on the dataset."""
    with ops.device(self._worker):
      self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
          self._dataset, self._devices)

  def get_next(self, device, name=None):
    """Get the next element for the given device."""
    del name
    with ops.device(self._worker):
      return self._iterator.get_next(device)

  def get_next_as_list(self, name=None):
    """Get the next element from the underlying iterator.

    If there is no data left, a list of dummy tensors with possible batch
    dimensions set to 0 will be returned.

    Args:
      name: not used.

    Returns:
      A boolean tensor indicating whether there is any data in the next
      element, and the real data as the next element (or a list of dummy
      tensors if there is no data left).
    """
    del name
    with ops.device(self._worker):
      data_list = self._iterator.get_next_as_optional()
      result = []
      for i, data in enumerate(data_list):
        # Place the condition op on the same device as the data so the data
        # doesn't need to be sent back to the worker.
        with ops.device(self._devices[i]):
          # MultiDeviceIterator fetches data in order, so we only need to
          # check whether the first replica has a value to see whether there
          # is data left for this single worker.
          if i == 0:
            worker_has_value = data.has_value()

          # pylint: disable=unnecessary-lambda
          # pylint: disable=cell-var-from-loop
          real_data = control_flow_ops.cond(
              data.has_value(),
              lambda: data.get_value(),
              lambda: _dummy_tensor_fn(data.value_structure))
          result.append(real_data)
          # pylint: enable=cell-var-from-loop
          # pylint: enable=unnecessary-lambda

      return worker_has_value, result

  def initialize(self):
    """Initialize the underlying iterator.

    In eager execution, this simply recreates the underlying iterator.
    In graph execution, it returns the initializer ops for the underlying
    iterator.

    Returns:
      A list of any initializer ops that should be run.
    """
    if context.executing_eagerly():
      self._iterator._eager_reset()  # pylint: disable=protected-access
      return []
    else:
      return [self._iterator.initializer]

  @property
  def output_classes(self):
    return dataset_ops.get_legacy_output_classes(self._iterator)

  @property
  def output_shapes(self):
    return dataset_ops.get_legacy_output_shapes(self._iterator)

  @property
  def output_types(self):
    return dataset_ops.get_legacy_output_types(self._iterator)


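# Illustrative sketch (not part of the original module): what
# `get_next_as_list` returns for a worker feeding two devices. While data
# remains, `has_value` evaluates to True and `result` holds one element per
# device; once the dataset is exhausted, `has_value` is False and `result`
# holds the zero-batch dummy tensors built by `_dummy_tensor_fn`.
#
#   has_value, result = single_worker_iterator.get_next_as_list()
#   # has_value: scalar boolean tensor
#   # result:    [element_on_device_0, element_on_device_1]

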
class _SingleWorkerCallableIterator(object):
  """Iterator for a single tensor-returning callable."""

  def __init__(self, fn, worker, devices):
    self._fn = fn
    self._worker = worker
    self._devices = devices

  def get_next(self, device, name=None):
    """Get next element for the given device from the callable."""
    del device, name
    with ops.device(self._worker):
      return self._fn()

  def get_next_as_list(self, name=None):
    """Get next element from the callable."""
    del name
    with ops.device(self._worker):
      data_list = [self._fn() for _ in self._devices]
      return constant_op.constant(True), data_list

  def initialize(self):
    # TODO(petebu) Should this throw an exception instead?
    return []


# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
def _get_batched_dataset(d):
  """Get the batched dataset from `d`."""
  # pylint: disable=protected-access
  if isinstance(d, dataset_ops.DatasetV1Adapter):
    d = d._dataset

  if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
    return d
  elif isinstance(d, (dataset_ops.PrefetchDataset,
                      dataset_ops._OptionsDataset)):
    return _get_batched_dataset(d._input_dataset)

  raise ValueError(
      "Unable to get batched dataset from the input dataset. `batch` or "
      "`map_and_batch` needs to be the last operation on the dataset. "
      "The batch operations can be followed by a prefetch.")


def _get_batched_dataset_attributes(d):
  """Get the `batch_size` and `drop_remainder` of the dataset."""
  # pylint: disable=protected-access
  assert isinstance(d,
                    (dataset_ops.BatchDataset, batching._MapAndBatchDataset))
  if isinstance(d, dataset_ops.BatchDataset):
    batch_size = d._batch_size
    drop_remainder = d._drop_remainder
  elif isinstance(d, batching._MapAndBatchDataset):
    batch_size = d._batch_size_t
    drop_remainder = d._drop_remainder_t
  # pylint: enable=protected-access

  if tensor_util.is_tensor(batch_size):
    batch_size = tensor_util.constant_value(batch_size)

  if tensor_util.is_tensor(drop_remainder):
    drop_remainder = tensor_util.constant_value(drop_remainder)

  return batch_size, drop_remainder


# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
def _get_dataset_attributes(dataset):
  """Get the underlying attributes from the dataset object."""
  # pylint: disable=protected-access

  # First, get batch_size and drop_remainder from the dataset. We need
  # to walk back the dataset creation process and find the batched version in
  # order to get the attributes.
  batched_dataset = _get_batched_dataset(dataset)
  batch_size, drop_remainder = _get_batched_dataset_attributes(batched_dataset)

  # Second, the prefetch buffer size should be obtained from the original
  # dataset.
  prefetch_buffer = None
  if isinstance(dataset, dataset_ops.PrefetchDataset):
    prefetch_buffer = dataset._buffer_size
  elif (isinstance(dataset, dataset_ops.DatasetV1Adapter)
        and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)):
    prefetch_buffer = dataset._dataset._buffer_size

  return batch_size, drop_remainder, prefetch_buffer


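# Illustrative sketch (not part of the original module): the attributes
# recovered from a dataset that was batched and then prefetched. `batch_size`
# and `drop_remainder` are constant-folded where possible; `prefetch_buffer`
# is the buffer-size attribute of the outer prefetch.
#
#   dataset = tf.data.Dataset.range(100).batch(8,
#                                              drop_remainder=True).prefetch(2)
#   batch_size, drop_remainder, prefetch_buffer = (
#       _get_dataset_attributes(dataset))
#   # batch_size == 8, drop_remainder == True, prefetch_buffer holds 2

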
class MultiStepContext(object):
  """A context object that can be used to capture things when running steps.

  This context object is useful when running multiple steps at a time using the
  `experimental_run_steps_on_iterator` API. For example, it allows the user's
  step function to specify which outputs to emit at what frequency. Currently
  it supports capturing output from the last step, as well as capturing non
  tensor outputs. In the future it will be augmented to support other use
  cases, such as outputting every N steps.
  """

  def __init__(self):
    """Initialize an output context.

    Returns:
      A context object.
    """
    self._last_step_outputs = {}
    self._last_step_outputs_reduce_ops = {}
    self._non_tensor_outputs = {}

  @property
  def last_step_outputs(self):
    """A dictionary consisting of outputs to be captured on last step.

    Keys in the dictionary are names of tensors to be captured, as specified
    when `set_last_step_output` is called.
    Values in the dictionary are the tensors themselves. If
    `set_last_step_output` was called with a `reduce_op` for this output,
    then the value is the reduced value.

    Returns:
      A dictionary with last step outputs.
    """
    return self._last_step_outputs

  def _set_last_step_outputs(self, outputs):
    """Replace the entire dictionary of last step outputs."""
    if not isinstance(outputs, dict):
      raise ValueError("Need a dictionary to set last_step_outputs.")
    self._last_step_outputs = outputs

  def set_last_step_output(self, name, output, reduce_op=None):
    """Set `output` with `name` to be output from the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match tensor
        name.
      output: The tensors that should be output with `name`. See below for
        actual types supported.
      reduce_op: Reduction method to use to reduce outputs from multiple
        replicas. Required if `set_last_step_output` is called in a replica
        context. Optional in cross_replica_context.
        When present, the outputs from all the replicas are reduced using the
        current distribution strategy's `reduce` method. Hence, the type of
        `output` must be what's supported by the corresponding `reduce` method.
        For example, if using MirroredStrategy and reduction is set, output
        must be a `PerReplica` value.
        The reduce method is also recorded in the
        `_last_step_outputs_reduce_ops` dictionary so it is known later
        whether the outputs have already been reduced.
    """
    if distribution_strategy_context.in_cross_replica_context():
      self._last_step_outputs_reduce_ops[name] = reduce_op
      if reduce_op is None:
        self._last_step_outputs[name] = output
      else:
        distribution = distribution_strategy_context.get_strategy()
        self._last_step_outputs[name] = distribution.reduce(reduce_op, output)
    else:
      assert reduce_op is not None
      def merge_fn(distribution, value):
        self._last_step_outputs[name] = distribution.reduce(reduce_op, value)
        # Setting this inside the `merge_fn` because all replicas share the same
        # context object, so it's more robust to set it only once (even if all
        # the replicas are trying to set the same value).
        self._last_step_outputs_reduce_ops[name] = reduce_op

      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, args=(output,))

  @property
  def non_tensor_outputs(self):
    """A dictionary consisting of any non tensor outputs to be captured."""
    return self._non_tensor_outputs

  def set_non_tensor_output(self, name, output):
    """Set `output` with `name` to be captured as a non tensor output."""
    if distribution_strategy_context.in_cross_replica_context():
      self._non_tensor_outputs[name] = output
    else:
      def merge_fn(distribution, value):
        # NOTE(priyag): For non tensor outputs, we simply return all the values
        # in a list as reduction doesn't make sense on non tensors.
        self._non_tensor_outputs[name] = (
            distribution.experimental_local_results(value))
      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, args=(output,))
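

# Illustrative usage sketch (not part of the original module): recording a
# reduced loss from a step function run with
# `strategy.extended.experimental_run_steps_on_iterator`. `compute_loss` and
# `train_op` are hypothetical placeholders for the user's model code.
#
#   def step_fn(ctx, inputs):
#     loss, train_op = compute_loss(inputs)  # hypothetical helper
#     ctx.set_last_step_output("loss", loss,
#                              reduce_op=tf.distribute.ReduceOp.MEAN)
#     ctx.set_non_tensor_output("tag", "final_step")
#     return train_op
#
#   ctx = strategy.extended.experimental_run_steps_on_iterator(
#       step_fn, iterator, iterations=10)
#   reduced_loss = ctx.last_step_outputs["loss"]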
    655