      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Utilities related to distributed training."""
     16 # pylint:disable=protected-access
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.python.data.ops import dataset_ops
     24 from tensorflow.python.data.ops import iterator_ops
     25 from tensorflow.python.distribute import distribute_coordinator_context as dc_context
     26 from tensorflow.python.distribute import reduce_util
     27 from tensorflow.python.eager import context
     28 from tensorflow.python.framework import dtypes
     29 from tensorflow.python.framework import ops
     30 from tensorflow.python.framework import tensor_util
     31 from tensorflow.python.keras import backend as K
     32 from tensorflow.python.keras import callbacks
     33 from tensorflow.python.keras import metrics as metrics_module
     34 from tensorflow.python.keras import optimizers
     35 from tensorflow.python.keras.engine import training_utils
     36 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
     37 from tensorflow.python.keras.utils.mode_keys import ModeKeys
     38 from tensorflow.python.ops import control_flow_ops
     39 from tensorflow.python.ops import math_ops
     40 from tensorflow.python.ops import variables
     41 from tensorflow.python.platform import tf_logging as logging
     42 from tensorflow.python.util import nest
     43 from tensorflow.python.util import tf_contextlib
     44 
     45 
     46 def set_weights(distribution_strategy, dist_model, weights):
     47   """Sets the weights of the replicated models.
     48 
     49   The weights of the replicated models are set to the weights of the original
     50   model. The weights of the replicated model are Mirrored variables and hence
     51   we need to use the `update` call within a DistributionStrategy scope.
     52 
     53   Args:
     54     distribution_strategy: DistributionStrategy used to distribute training
     55         and validation.
     56     dist_model: The replicated models on the different devices.
     57     weights: The weights of the original model.
     58   """
     59   assign_ops = []
     60   for layer in dist_model.layers:
     61     num_param = len(layer.weights)
     62     layer_weights = weights[:num_param]
     63     for sw, w in zip(layer.weights, layer_weights):
     64       if ops.executing_eagerly_outside_functions():
     65         sw.assign(w)
     66       else:
     67         assign_ops.append(distribution_strategy.unwrap(sw.assign(w)))
     68     weights = weights[num_param:]
     69 
     70   if not ops.executing_eagerly_outside_functions():
     71     K.get_session(assign_ops).run(assign_ops)
     72 
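# Illustrative usage sketch (not executed; `strategy`, `orig_model` and
# `distributed_model` are hypothetical placeholders), mirroring how
# `_copy_weights_to_distributed_model` below uses this helper:
#
#   first_replica_model = strategy.unwrap(distributed_model)[0]
#   set_weights(strategy, first_replica_model, orig_model.get_weights())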
     73 
     74 def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs,
     75                   grouped_updates=None, grouped_session_args=None,
     76                   with_loss_tensor=False):
     77   """Unwrap and return the list of values contained in the PerDevice parameters.
     78 
     79   This function calls `flatten_perdevice_values` to parse each of the input
     80   parameters into a list of values on the different devices. If we set
     81   `with_loss_tensor` to be True, we also call `reduce` on the list of losses on
     82   the different devices to give us one loss tensor.
     83 
     84   Args:
     85     distribution_strategy: DistributionStrategy used to distribute training and
     86         validation.
     87     grouped_inputs: PerDevice inputs returned from the train or test function
     88         that we ran on each device.
     89     grouped_outputs: PerDevice outputs returned from the train or test function
     90         that we ran on each device.
     91     grouped_updates: PerDevice updates returned from the train or test function
     92         that we ran on each device.
     93     grouped_session_args: PerDevice session args returned from the train or
     94         test function that we ran on each device.
     95     with_loss_tensor: Boolean that indicates if we need to add the reduced loss
     96         tensor as one of the outputs.
     97 
     98   Returns:
     99     Values of each of the PerDevice parameters.
    100 
    101   """
    102   # Unwrap per device values returned from each model's train function.
    103   # This will be used to construct the main train function.
    104   all_inputs = flatten_perdevice_values(distribution_strategy,
    105                                         grouped_inputs)
    106   if with_loss_tensor:
    107     # reduce loss tensor before adding it to the list of fetches
    108     loss = distribution_strategy.reduce(reduce_util.ReduceOp.SUM,
    109                                         grouped_outputs[0])
    110     all_outputs = flatten_perdevice_values(distribution_strategy,
    111                                            grouped_outputs[1:])
    112     all_outputs = [loss] + all_outputs
    113   else:
    114     all_outputs = flatten_perdevice_values(distribution_strategy,
    115                                            grouped_outputs)
    116 
    117   if grouped_updates:
    118     all_updates = flatten_perdevice_values(distribution_strategy,
    119                                            grouped_updates)
    120   else:
    121     all_updates = None
    122 
    123   all_session_args = {}
    124   if grouped_session_args:
    125     grouped_feed_dict = grouped_session_args.get('feed_dict')
    126     if grouped_feed_dict:
    127       all_session_args['feed_dict'] = flatten_perdevice_values(
    128           distribution_strategy, grouped_feed_dict)
    129 
    130     grouped_fetches = grouped_session_args.get('fetches')
    131     if grouped_fetches:
    132       all_session_args['fetches'] = flatten_perdevice_values(
    133           distribution_strategy, grouped_fetches)
    134 
    135   # TODO(priyag): Return only non empty/None values
    136   return all_inputs, all_outputs, all_updates, all_session_args
    137 
    138 
    139 def flatten_perdevice_values(distribution_strategy, perdevice_values):
    140   """Unwraps and flattens a nest of PerDevice parameters.
    141 
    142   PerDevice values have one value associated with each device. Each entry in
    143   the PerDevice dict has a device `key` and the corresponding value on the
    144   device as the `value`. In this function we take a PerDevice value or a list of
    145   PerDevice values and return all the values in the PerDevice dict.
    146 
    147   Args:
    148     distribution_strategy: DistributionStrategy used to distribute training and
    149         validation.
    150     perdevice_values: List of PerDevice object or a single PerDevice object.
    151 
    152   Returns:
    153     List of values of all the PerDevice objects.
    154 
    155   """
    156   # This function takes a PerDevice object or a list of PerDevice objects and
    157   # returns all the values associated with it.
    158   return [e for flattened in nest.flatten(perdevice_values)
    159           for e in distribution_strategy.unwrap(flattened)]
    160 
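# Conceptual example of the flattening, assuming a two-replica strategy
# (device names and the tensors a0, a1, b0, b1 are illustrative):
#
#   perdevice_values = [PerDevice({'/gpu:0': a0, '/gpu:1': a1}),
#                       PerDevice({'/gpu:0': b0, '/gpu:1': b1})]
#   flatten_perdevice_values(strategy, perdevice_values)
#   # -> [a0, a1, b0, b1]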
    161 
    162 def validate_callbacks(input_callbacks, optimizer):
    163   """Validate whether given callbacks are supported by DistributionStrategy.
    164 
    165   Args:
    166     input_callbacks: List of callbacks passed by the user to fit.
    167     optimizer: Optimizer instance used to train the model.
    168 
  Raises:
    ValueError: If `LearningRateScheduler` or `ReduceLROnPlateau` is passed
        without a Keras `OptimizerV2` optimizer. Unsupported TensorBoard
        options (`histogram_freq`, `write_grads`) are disabled with a warning
        instead of raising.
    174   """
    175   if input_callbacks:
    176     for callback in input_callbacks:
      supported_predefined_callbacks = (
          callbacks.TensorBoard, callbacks.ReduceLROnPlateau,
          callbacks.LearningRateScheduler, callbacks.CSVLogger,
          callbacks.EarlyStopping, callbacks.ModelCheckpoint,
          callbacks.TerminateOnNaN, callbacks.ProgbarLogger,
          callbacks.History, callbacks.RemoteMonitor)
      if not isinstance(callback, supported_predefined_callbacks):
        logging.warning('Your input callback is not one of the predefined '
                        'callbacks that support DistributionStrategy. You '
                        'might encounter an error if you access one of the '
                        'model\'s attributes as part of the callback, since '
                        'these attributes are not set. You can access each of '
                        'the individual distributed models using the '
                        '`_grouped_model` attribute of your original model.')
    189       if isinstance(callback, (callbacks.LearningRateScheduler,
    190                                callbacks.ReduceLROnPlateau)):
    191 
    192         if not isinstance(optimizer, optimizer_v2.OptimizerV2):
    193           raise ValueError('You must specify a Keras Optimizer V2 when using '
    194                            '%s callback with DistributionStrategy.' % callback)
    195 
    196       # If users want to use the TensorBoard callback they cannot use certain
    197       # features of the callback that involve accessing model attributes and
    198       # running ops.
    199       if isinstance(callback, callbacks.TensorBoard):
    200         if getattr(callback, 'histogram_freq', False):
    201           logging.warning(
    202               UserWarning(
    203                   '`histogram_freq` in the TensorBoard callback is not '
    204                   'supported when using DistributionStrategy. Setting '
    205                   '`histogram_freq` to `0`.'))
    206           callback.histogram_freq = 0
    207         if getattr(callback, 'write_grads', False):
    208           logging.warning(
    209               UserWarning(
    210                   '`write_grads` in the TensorBoard callback is not supported '
    211                   'when using DistributionStrategy. Setting `write_grads` '
    212                   'to `False`.'))
          callback.write_grads = False
    214 
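# Illustrative sketch (hypothetical objects): `validate_callbacks` warns about
# callbacks outside the predefined set, downgrades unsupported TensorBoard
# options with a warning, and raises only when an LR-scheduling callback is
# used without a Keras V2 optimizer.
#
#   cbs = [callbacks.ModelCheckpoint('/tmp/ckpt.h5'),
#          callbacks.LearningRateScheduler(lambda epoch: 1e-3)]
#   validate_callbacks(cbs, model.optimizer)
#   # Raises ValueError unless `model.optimizer` is an `OptimizerV2`.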
    215 
    216 def validate_distributed_dataset_inputs(distribution_strategy, x, y,
    217                                         sample_weights=None):
    218   """Validate all the components of a DistributedValue Dataset input.
    219 
    220   Args:
    221     distribution_strategy: The current DistributionStrategy used to call
    222         `fit`/`evaluate`.
    223     x: Input Dataset DistributedValue object. For example, when we use
    224         `MirroredStrategy` this is a PerDevice object with a tensor for each
    225         device set in the dict. x can also be a tuple or dict. The keys of the
    226         dict should match the names of the input layers of the model.
    227     y: Target Dataset DistributedValue object. For example, when we use
    228         `MirroredStrategy` this is a PerDevice object with a tensor for each
    229         device set in the dict. y can also be a tuple or dict. The keys of the
    230         dict should match the names of the output layers of the model.
    231     sample_weights: Sample weights Dataset DistributedValue object. For example,
    232         when we use `MirroredStrategy` this is a PerDevice object with a tensor
    233         for each device set in the dict.
    234 
    235   Returns:
    236     The unwrapped values list of the x and y DistributedValues inputs.
    237 
    238   Raises:
    ValueError: If x and y cannot be evaluated as tensors, if they contain
        elements that are not tensors, or if their elements have a shape or
        dtype mismatch.
    242   """
    243   # If the input and target used to call the model are not dataset tensors,
    244   # we need to raise an error. When using a DistributionStrategy, the input
    245   # and targets to a model should be from a `tf.data.Dataset`.
    246 
  # Unless each element of x and y is a tensor, we cannot standardize and
  # validate the inputs and targets.
    249   x_values_list = validate_per_device_inputs(distribution_strategy, x)
    250 
    251   if y is not None:
    252     y_values_list = validate_per_device_inputs(distribution_strategy, y)
    253   else:
    254     y_values_list = None
    255 
    256   if sample_weights is not None:
    257     sample_weights_list = validate_per_device_inputs(distribution_strategy,
    258                                                      sample_weights)
    259   else:
    260     sample_weights_list = None
    261 
    262   # Return the unwrapped values to avoid calling `unwrap` a second time.
    263   return x_values_list, y_values_list, sample_weights_list
    264 
    265 
    266 def validate_per_device_inputs(distribution_strategy, x):
    267   """Validates PerDevice dataset input list.
    268 
    269   Args:
    270     distribution_strategy: The current DistributionStrategy used to call
    271       `fit`, `evaluate` and `predict`.
    272     x: A list of PerDevice objects that represent the input or
    273       target values.
    274 
    275   Returns:
    276     List containing the first element of each of the PerDevice objects in
    277     the input list.
    278 
    279   Raises:
    280     ValueError: If any of the objects in the `per_device_list` is not a tensor.
    281 
    282   """
    283   # Convert the inputs and targets into a list of PerDevice objects.
    284   per_device_list = nest.flatten(x)
    285   x_values_list = []
    286   for x in per_device_list:
    287     if not tensor_util.is_tensor(x):
      raise ValueError('Dataset inputs to the model should be tensors. '
                       'Instead they are of type {}.'.format(type(x)))
    290 
    291     # At this point both x and y contain tensors in the `DistributedValues`
    292     # structure.
    293     x_values = distribution_strategy.unwrap(x)
    294 
    295     # Validate that the shape and dtype of all the elements in x are the same.
    296     validate_all_tensor_shapes(x, x_values)
    297     validate_all_tensor_types(x, x_values)
    298 
    299     x_values_list.append(x_values[0])
    300   return x_values_list
    301 
    302 
    303 def validate_all_tensor_types(x, x_values):
    304   x_dtype = x_values[0].dtype
    305   for i in range(1, len(x_values)):
    306     if x_dtype != x_values[i].dtype:
    307       raise ValueError('Input tensor dtypes do not match for distributed tensor'
    308                        ' inputs {}'.format(x))
    309 
    310 
    311 def validate_all_tensor_shapes(x, x_values):
  # Validate that all the elements in x have the same shape.
    313   x_shape = x_values[0].get_shape().as_list()
    314   for i in range(1, len(x_values)):
    315     if x_shape != x_values[i].get_shape().as_list():
    316       raise ValueError('Input tensor shapes do not match for distributed tensor'
    317                        ' inputs {}'.format(x))
    318 
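# Conceptual example: with two replicas, the unwrapped components of a single
# distributed input must agree in both shape and dtype, e.g. each replica sees
# a float32 tensor of shape (16, 28, 28). A mismatch in either property raises
# a ValueError naming the offending distributed input.
#
#   x_values = distribution_strategy.unwrap(x)  # e.g. [t_replica0, t_replica1]
#   validate_all_tensor_shapes(x, x_values)     # ok if shapes match
#   validate_all_tensor_types(x, x_values)      # ok if dtypes match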
    319 
    320 def _wait_for_variable_initialization(session):
    321   """Utility to wait for variables to be initialized."""
    322   all_variables = K._get_variables(K.get_graph())  # pylint: disable=protected-access
    323   candidate_vars = []
    324   for v in all_variables:
    325     if not getattr(v, '_keras_initialized', False):
    326       candidate_vars.append(v)
    327 
    328   if not candidate_vars:
    329     return
    330 
    331   while True:
    332     is_initialized = session.run(
    333         [variables.is_variable_initialized(v) for v in candidate_vars])
    334     uninitialized_vars = []
    335     for flag, v in zip(is_initialized, candidate_vars):
    336       if not flag:
    337         uninitialized_vars.append(v)
    338       v._keras_initialized = True  # pylint: disable=protected-access
    339     if not uninitialized_vars:
    340       break
    341 
    342 
    343 def init_restore_or_wait_for_variables():
    344   """Initialize or restore variables or wait for variables to be initialized."""
    345   session = K._get_session()  # pylint: disable=protected-access
    346   worker_context = dc_context.get_current_worker_context()
    347   if not worker_context or worker_context.experimental_should_init:
    348     # TODO(yuefengz): if checkpoints exist, restore from checkpoint.
    349     K._initialize_variables(session)  # pylint: disable=protected-access
    350   else:
    351     _wait_for_variable_initialization(session)
    352 
    353 
    354 def validate_inputs(x, y, distribution_strategy, allow_partial_batch=False):
    355   """Validate inputs when using DistributionStrategy.
    356 
    357   Args:
    358     x: Model Inputs.
    359     y: Model Targets.
    360     distribution_strategy: The DistributionStrategy with which the model is
    361       compiled.
    362     allow_partial_batch: Boolean. If false, datasets must have fully
    363       defined shapes.
    364 
    365   Raises:
    ValueError: If the input is not a Dataset or a numpy array (when using
      MirroredStrategy).
    368   """
    369   if (isinstance(x, iterator_ops.Iterator) or
    370       isinstance(y, iterator_ops.Iterator)):
    371     raise ValueError('`DistributionStrategy` does not support inputs of type '
    372                      'Iterator. You must pass a `tf.data.Dataset` object or a '
    373                      'numpy array as input.')
    374 
    375   if is_tpu_strategy(distribution_strategy):
    376     for i in [x, y]:
    377       if (isinstance(i, dataset_ops.DatasetV2) and not allow_partial_batch):
    378         if not is_dataset_shape_fully_defined(i):
    379           raise ValueError(
    380               'Using TPUs currently requires fully defined shapes. Either use '
    381               'set_shape() on the input tensors or use '
              'dataset.batch(..., drop_remainder=True). '
    383               'Found unknown shape in input {}.'.format(i))
    384 
    385 
    386 # TODO(b/118776054): Currently we support global batch size for TPUStrategy and
    387 # core MirroredStrategy only. Remove this check when contrib MirroredStrategy is
    388 # no longer needed.
    389 def global_batch_size_supported(distribution_strategy):
    390   return distribution_strategy.extended._global_batch_size  # pylint: disable=protected-access
    391 
    392 
    393 # TODO(sourabhbajaj): Remove this once we use the same API for all strategies.
    394 def is_tpu_strategy(strategy):
    395   """We're executing TPU Strategy."""
  """Returns whether `strategy` is a TPUStrategy."""
    397 
    398 
    399 def is_dataset_shape_fully_defined(dataset):
  """Returns whether all the output shapes of `dataset` are fully defined."""
    401   shapes = nest.flatten(dataset_ops.get_legacy_output_shapes(dataset))
    402   unknown_shapes = [s for s in shapes if not s.is_fully_defined()]
    403   return not unknown_shapes
    404 
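# Illustrative example using the public tf.data API (assuming `features` is a
# NumPy array and `tf` is the public TensorFlow module in user code):
#
#   ds = tf.data.Dataset.from_tensor_slices(features).batch(32)
#   is_dataset_shape_fully_defined(ds)  # False: batch dimension is None.
#
#   ds = tf.data.Dataset.from_tensor_slices(features).batch(
#       32, drop_remainder=True)
#   is_dataset_shape_fully_defined(ds)  # True: all dimensions are static.
#
# Dropping the remainder is what `validate_inputs` relies on for TPUs.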
    405 
    406 def get_input_params(distribution_strategy, first_x_value, steps, batch_size,
    407                      mode=None):
    408   """Calculate the number of batches and steps/steps_per_epoch.
    409 
    410   Args:
    411     distribution_strategy: The DistributionStrategy used to compile the model.
    412     first_x_value: This is the first input numpy array that is passed in as the
    413       model input.
    414     steps:  The specified number of steps.
    415     batch_size: The specified batch_size.
    mode: ModeKey representing whether the input will be used for training,
      evaluation, or prediction. This is used to relax the constraint of
      consuming all the samples, for compatibility until partial batches are
      fully supported. If None, partial batches are not allowed.

  Returns:
    steps: The steps or steps_per_epoch argument, depending on whether the
        user is calling `fit`, `evaluate` or `predict`. In training mode we
        do not require all the samples to be consumed.
    425     batch_size: The batch size to be used in model iterations.
    426 
    427   Raises:
    428     ValueError: If the number of batches or steps evaluates to 0.
    429 
    430   """
    431   num_samples = first_x_value.shape[0]
    432   # TODO(b/118776054): Use global batch size for Keras/DS support.
    433   # Currently this is only supported in TPUStrategy and CoreMirroredStrategy.
    434   use_per_replica_batch = not global_batch_size_supported(
    435       distribution_strategy)
    436 
    437   # Partial batches are allowed for training as we repeat the
    438   # dataset when converting numpy arrays into a dataset.
    439   # For other modes uneven batch sizes are not allowed except
    440   # for `predict()` on TPUStrategy.
    441   allow_partial_batch = (mode == ModeKeys.TRAIN or
    442                          (mode == ModeKeys.PREDICT
    443                           and is_tpu_strategy(distribution_strategy)))
    444 
    445   if steps is None:
    446     if batch_size is None:
      # If neither the batch size nor the number of steps is set, we choose
      # the global batch size as the minimum of the number of samples and 32.
      # 32 is chosen to provide backward compatibility.
    450       global_batch_size = min(num_samples, 32)
    451     else:
    452       # If the user provided the batch size we need to handle the case
    453       # between different strategies that use the global/per-replica batch size
    454       global_batch_size = batch_size
    455       if use_per_replica_batch:
    456         global_batch_size *= distribution_strategy.num_replicas_in_sync
    457     if allow_partial_batch:
    458       steps = np.ceil(num_samples / global_batch_size).astype(int)
    459     else:
    460       if num_samples % global_batch_size:
    461         raise ValueError('The number of samples %s is not divisible by '
    462                          'batch size %s.' % (num_samples, global_batch_size))
    463       steps = num_samples // global_batch_size
    464   else:
    465     if batch_size is None:
    466       # We calculate the batch size based on the number of steps specified
    467       if num_samples % steps:
    468         raise ValueError('The number of samples %s is not divisible by '
    469                          'steps %s. Please change the number of steps to a '
                         'value that can consume all the samples.' % (
    471                              num_samples, steps))
    472       global_batch_size = num_samples // steps
    473     else:
    474       # If the user provided the batch size we need to handle the case
    475       # between different strategies that use the global/per-replica batch size
    476       global_batch_size = batch_size
    477       if use_per_replica_batch:
    478         global_batch_size *= distribution_strategy.num_replicas_in_sync
    479 
    480       min_num_samples = global_batch_size * steps
    481       if allow_partial_batch:
    482         min_num_samples = global_batch_size * (steps-1) + 1 if steps > 1 else 0
    483 
    484       if num_samples < min_num_samples:
    485         raise ValueError('Number of samples %s is less than samples required '
    486                          'for specified batch_size %s and steps %s' % (
    487                              num_samples, global_batch_size, steps))
    488 
    489   # We need to return the per replica or global batch size based on the strategy
    490   if use_per_replica_batch:
    491     if global_batch_size % distribution_strategy.num_replicas_in_sync:
    492       raise ValueError(
    493           'The batch size (%s) could not be sharded evenly across the sync '
    494           'replicas (%s) in the distribution strategy.' % (
    495               global_batch_size, distribution_strategy.num_replicas_in_sync))
    496     batch_size = global_batch_size // distribution_strategy.num_replicas_in_sync
    497   else:
    498     batch_size = global_batch_size
    499 
    500   return steps, batch_size
    501 
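# Worked example (hypothetical numbers): 1000 samples, batch_size=64,
# steps=None, mode=ModeKeys.TRAIN, and a strategy with 2 replicas that uses
# per-replica batch sizes. The global batch size becomes 64 * 2 = 128;
# training allows a partial final batch, so steps = ceil(1000 / 128) = 8, and
# the returned per-replica batch size is 128 // 2 = 64. In TEST mode the same
# inputs would raise a ValueError because 1000 is not divisible by 128.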
    502 
    503 def get_batch_dimension(iterator):
    504   shapes = nest.flatten(dataset_ops.get_legacy_output_shapes(iterator))
    505   # Take the batch size from the first element, as it should be the same for
    506   # all.
    507   dims = shapes[0].dims
    508   return dims[0] if dims else None
    509 
    510 
    511 def list_to_tuple(maybe_list):
    512   """Datasets treat lists specially, so switch them to tuples."""
    513   if isinstance(maybe_list, list):
    514     return tuple(maybe_list)
    515   return maybe_list
    516 
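# Example: tf.data treats a Python list as a single component rather than a
# structure of components, so multi-input data is converted first, e.g.
#
#   list_to_tuple([image_array, text_array])  # -> (image_array, text_array)
#   list_to_tuple({'a': 1})                   # -> {'a': 1} (unchanged)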
    517 
    518 def get_iterator(dataset, distribution_strategy):
    519   with distribution_strategy.scope():
    520     iterator = distribution_strategy.make_dataset_iterator(dataset)
    521   initialize_iterator(iterator, distribution_strategy)
    522   return iterator
    523 
    524 
    525 def initialize_iterator(iterator, distribution_strategy):
    526   with distribution_strategy.scope():
    527     init_op = control_flow_ops.group(iterator.initialize())
    528     if not context.executing_eagerly():
    529       K.get_session((init_op,)).run(init_op)
    530 
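# Illustrative sketch (hypothetical `dataset`, `strategy` and `model`): the two
# helpers above are typically used together to obtain an initialized
# distributed iterator whose elements feed the per-replica functions.
#
#   iterator = get_iterator(dataset, strategy)  # also runs the initializer
#   x, y, sample_weights = _get_input_from_iterator(iterator, model)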
    531 
    532 def _get_input_from_iterator(iterator, model):
    533   """Get elements from the iterator and verify the input shape and type."""
    534   next_element = iterator.get_next()
    535 
    536   if len(nest.flatten(next_element)) == len(model.inputs):
    537     x = next_element
    538     y = None
    539     sample_weights = None
    540   elif len(nest.flatten(next_element)) == (len(model.inputs) +
    541                                            len(model.outputs)):
    542     x, y = next_element
    543     sample_weights = None
    544   else:
    545     x, y, sample_weights = next_element
    546 
    547   # Validate that all the elements in x and y are of the same type and shape.
    548   validate_distributed_dataset_inputs(
    549       model._distribution_strategy, x, y, sample_weights)
    550   return x, y, sample_weights
    551 
    552 
    553 def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
    554   """Prepare feed values to the model execution function.
    555 
    556   Arguments:
    557     model: Model to prepare feed values for.
    558     inputs: List or dict of model inputs.
    559     targets: Optional list of model targets.
    560     sample_weights: Optional list of sample weight arrays.
    561     mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
    562 
    563   Returns:
    564     Feed values for the model in the given mode.
    565   """
    566   strategy = model._distribution_strategy
    567   inputs, targets, sample_weights = _get_input_from_iterator(inputs, model)
    568   inputs = flatten_perdevice_values(strategy, inputs)
    569   targets = flatten_perdevice_values(strategy, targets)
    570   # Expand 1-dimensional inputs.
  # TODO(b/124535720): Remove once this data standardization logic is shared
  # with the main flow.
    573   inputs, targets = nest.map_structure(training_utils.standardize_single_array,
    574                                        (inputs, targets))
    575   if mode == ModeKeys.PREDICT:
    576     sample_weights = []
    577     targets = []
    578   else:
    579     sample_weights = [
    580         None for _ in range(len(model.outputs) * strategy.num_replicas_in_sync)
    581     ]
    582   ins = inputs + targets + sample_weights
    583   if mode == ModeKeys.TRAIN and not isinstance(K.symbolic_learning_phase(),
    584                                                int):
    585     ins += [True]
    586   return ins
    587 
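# Conceptual layout of the returned feed values for a model with one input and
# one output under a two-replica strategy (names hypothetical):
#
#   TRAIN/TEST: ins = [x_replica0, x_replica1, y_replica0, y_replica1,
#                      None, None]
#               (TRAIN additionally appends `True` when the learning phase is
#               a symbolic tensor rather than an int.)
#   PREDICT:    ins = [x_replica0, x_replica1]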
    588 
    589 def _custom_compile_for_predict(model):
    590   """Custom compile for TPU predict mode."""
    591   if not model.built:
    592     # Model is not compilable because it does not know its number of inputs
    593     # and outputs, nor their shapes and names. We will compile after the first
    594     # time the model gets called on training data.
    595     return
    596   model._is_compiled = True
    597   model.total_loss = None
    598   model.train_function = None
    599   model.test_function = None
    600   model.predict_function = None
    601 
    602 
    603 def _build_network_on_replica(model, mode, inputs=None, targets=None):
    604   """Build an updated model on replicas.
    605 
    606   We create a new Keras model while sharing the variables from the old graph.
  Building a new sub-graph is required since the original Keras model creates
  placeholders for the input and the output that are not accessible until we
  call iterator.get_next() inside the step_fn for `fit`/`evaluate`/`predict`.

  The sharing of weights and layers between the old and the new model guarantees
  that we're using Strategy variables and that any updates on either model are
  reflected correctly in callbacks and loop iterations.
    614 
    615   We need to make sure we share the optimizers between the old and the new model
    616   as well so that optimizer state is not lost if the user is running fit
    617   multiple times.
    618 
    619   Args:
    620     model: Model to be replicated across Replicas
    621     mode: Which of fit/eval/predict is building the distributed network
    622     inputs: Input variables to be passed to the model
    623     targets: Target tensor to be passed to model.compile
    624 
    625   Returns:
    626     A new model with shared layers with the old model.
    627   """
    628   # Need to do imports here since we run into a circular dependency error.
    629   from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
    630   from tensorflow.python.keras.engine import sequential  # pylint: disable=g-import-not-at-top
    631 
  # We rely on the internal cloning methods to avoid exposing the
  # `share_weights` argument in the public API.
    634   if isinstance(model, sequential.Sequential):
    635     updated_model = models._clone_sequential_model(model, input_tensors=inputs,
    636                                                    share_weights=True)
    637   else:
    638     updated_model = models._clone_functional_model(model, input_tensors=inputs,
    639                                                    share_weights=True)
    640 
  # Recast all low precision outputs back to float32 since we only cast
  # the inputs to bfloat16, not the targets. This is done so that we can
  # preserve precision when calculating the loss value.
    644   def _upcast_low_precision_outputs(output):
    645     if output.dtype == dtypes.bfloat16:
    646       return math_ops.cast(output, dtypes.float32)
    647     else:
    648       return output
    649   updated_model.outputs = [_upcast_low_precision_outputs(o)
    650                            for o in updated_model.outputs]
    651 
    652   if isinstance(targets, tuple):
    653     targets = nest.flatten(targets)
    654 
    655   if mode == ModeKeys.PREDICT and inputs is not None:  # TPU predict case
    656     _custom_compile_for_predict(updated_model)
    657   else:
    658     updated_model.compile(
    659         model.optimizer,
    660         model.loss,
    661         metrics=metrics_module.clone_metrics(model._compile_metrics),
    662         loss_weights=model.loss_weights,
    663         sample_weight_mode=model.sample_weight_mode,
    664         weighted_metrics=metrics_module.clone_metrics(
    665             model._compile_weighted_metrics),
    666         target_tensors=targets)
    667   return updated_model
    668 
    669 
    670 def _build_distributed_network(model, strategy, mode, inputs=None,
    671                                targets=None):
    672   """Create a cloned model on each replica."""
    673   with K.get_graph().as_default(), strategy.scope():
    674     distributed_model = strategy.extended.call_for_each_replica(
    675         _build_network_on_replica,
    676         args=(model, mode, inputs, targets))
    677     set_distributed_model(model, mode, distributed_model)
    678 
    679 
    680 def _clone_and_build_model(model, mode, inputs=None, targets=None):
    681   """Clone and build the given keras_model."""
    682   # We need to set the import here since we run into a circular dependency
    683   # error.
    684   from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
    685   cloned_model = models.clone_model(model, input_tensors=inputs)
    686 
    687   # Compile and build model.
    688   if isinstance(model.optimizer, optimizers.TFOptimizer):
    689     optimizer = model.optimizer
    690   else:
    691     optimizer_config = model.optimizer.get_config()
    692     optimizer = model.optimizer.__class__.from_config(optimizer_config)
    693 
  # Recast all low precision outputs back to float32 since we only cast
  # the inputs to bfloat16, not the targets. This is done so that we can
  # preserve precision when calculating the loss value.
    697   def _upcast_low_precision_outputs(output):
    698     if output.dtype == dtypes.bfloat16:
    699       return math_ops.cast(output, dtypes.float32)
    700     else:
    701       return output
    702   cloned_model.outputs = [_upcast_low_precision_outputs(o)
    703                           for o in cloned_model.outputs]
    704 
    705   if isinstance(targets, tuple):
    706     targets = nest.flatten(targets)
    707   if mode == ModeKeys.PREDICT and inputs is not None:  # TPU predict case
    708     _custom_compile_for_predict(cloned_model)
    709   else:
    710     cloned_model.compile(
    711         optimizer,
    712         model.loss,
    713         metrics=metrics_module.clone_metrics(model._compile_metrics),
    714         loss_weights=model.loss_weights,
    715         sample_weight_mode=model.sample_weight_mode,
    716         weighted_metrics=metrics_module.clone_metrics(
    717             model._compile_weighted_metrics),
    718         target_tensors=targets)
    719   return cloned_model
    720 
    721 
    722 def clone_model_on_replicas(model, strategy, mode, inputs=None, targets=None):
    723   """Create a cloned model on each replica."""
    724   with K.get_graph().as_default(), strategy.scope():
    725     distributed_model = strategy.extended.call_for_each_replica(
    726         _clone_and_build_model, args=(model, mode, inputs, targets))
    727     set_distributed_model(model, mode, distributed_model)
    728   if mode == ModeKeys.TRAIN:
    729     model._make_callback_model(distributed_model)
    730 
    731 
    732 def _make_execution_function(model, mode):
    733   """Makes or reuses function to run one step of distributed model execution."""
    734   strategy = model._distribution_strategy
    735 
    736   distributed_model = get_distributed_model(model, mode)
    737   # If distributed model for a particular `mode` is already built, use the
    738   # `_distribution_function` on that distributed model.
    739   if distributed_model:
    740     return distributed_model._distributed_function
    741 
    742   # If distributed_model is not built, create one for `mode`.
    743   if model._compile_distribution:
    744     clone_model_on_replicas(model, strategy, mode)
    745   else:
    746     _build_distributed_network(model, strategy, mode)
    747 
  # We've just created the distributed model, so `distributed_model` should
  # not be None.
    750   distributed_model = get_distributed_model(model, mode)
    751   assert distributed_model
    752 
  # Also create an execution function on that distributed model.
    754   if context.executing_eagerly():
    755     distributed_function = _make_eager_execution_function(model, mode)
    756   else:
    757     distributed_function = _make_graph_execution_function(model, mode)
    758 
  # We cache the distributed execution function on the model since creating
  # distributed models and execution functions is expensive.
    761   distributed_model._distributed_function = distributed_function
    762   return distributed_function
    763 
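# Illustrative call pattern (hypothetical `model` compiled under a strategy):
# the execution function is built once per mode and then served from the
# distributed model's cache on subsequent calls.
#
#   train_fn = _make_execution_function(model, ModeKeys.TRAIN)
#   outs = train_fn(ins)  # `ins` as produced by `_prepare_feed_values`.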
    764 
    765 def _make_graph_execution_function(model, mode):
    766   """Makes function to run one step of distributed model in graph mode."""
    767 
    768   def _per_device_function(model):
    769     f = model._make_execution_function(mode)
    770     return (f.inputs, f.outputs, f.updates_op, f.session_kwargs)
    771 
    772   strategy = model._distribution_strategy
    773   with strategy.scope():
    774     # Create train ops on each of the devices when we call
    # `_per_device_function`.
    776     (grouped_inputs, grouped_outputs, grouped_updates,
    777      grouped_session_args) = strategy.extended.call_for_each_replica(
    778          _per_device_function, args=(get_distributed_model(model, mode),))
    779 
    780     # Initialize the variables in the replicated model. This is necessary for
    781     # multi-worker training because on some workers, initialization is not
    782     # needed. This method does initialization or waiting for initialization
    783     # according to the context object of distribute coordinator.
    784     init_restore_or_wait_for_variables()
    785 
    786     # Unwrap all the per device values returned from `call_for_each_replica`.
    787     # Unwrapping per device values gives you a list of values that can be
    788     # used to construct a new train function that is composed of update ops on
    789     # all the devices over which the model is distributed.
    790     (all_inputs, all_outputs, all_updates, all_session_args) = unwrap_values(
    791         strategy,
    792         grouped_inputs,
    793         grouped_outputs,
    794         grouped_updates,
    795         grouped_session_args,
    796         with_loss_tensor=(mode != ModeKeys.PREDICT))
    797 
    798     return K.function(
    799         all_inputs,
    800         all_outputs,
    801         updates=all_updates,
    802         name='distributed_{}_function'.format(mode),
    803         **all_session_args)
    804 
    805 
    806 def _make_eager_execution_function(model, mode):
    807   """Makes function to run one step of distributed model eager execution."""
    808   def _per_device_function(model):
    809     f = model._make_execution_function(mode)
    810     return (f.inputs, f.outputs)
    811 
    812   # NOTE(priyag): Try creating a new FuncGraph within DS scope instead of using
    813   # the global one.
    814   strategy = model._distribution_strategy
    815   global_graph = K.get_graph()
    816 
    817   with global_graph.as_default(), strategy.scope():
    818     # First we gather the relevant portions of the model across all replicas.
    819     # `K._scratch_graph(global_graph)` signals to Keras that it should not
    820     # lift to a separate graph when creating the per-replica functions.
    821     with K._scratch_graph(global_graph):
    822       # Create train ops on each of the devices when we call
      # `_per_device_function`.
    824       grouped = strategy.extended.call_for_each_replica(
    825           _per_device_function, args=(get_distributed_model(model, mode),))
    826       grouped_inputs, grouped_outputs = grouped
    827 
    828       # Unwrap all the per device values returned from `call_for_each_replica`.
    829       # Unwrapping per device values gives you a list of values that can be
    830       # used to construct a new train function that is composed of
    831       # inputs/outputs on all the devices over which the model is distributed.
    832       (all_inputs, all_outputs, _, _) = unwrap_values(
    833           strategy,
    834           grouped_inputs,
    835           grouped_outputs,
    836           with_loss_tensor=(mode != ModeKeys.PREDICT))
    837 
    838     # Finally, a joint Keras function is created; this one will be created in
    839     # a separate FuncGraph.
    840     return K.function(
    841         all_inputs,
    842         all_outputs,
    843         name='eager_distributed_{}_function'.format(mode))
    844 
    845 
    846 def _copy_weights_to_distributed_model(original_model, mode):
    847   """Copies weights from original model to distributed models."""
    848   strategy = original_model._distribution_strategy
    849   distributed_model = get_distributed_model(original_model, mode)
    850   if strategy:
    851     # Copy the weights from the original model to each of the replicated
    852     # models.
    853     orig_model_weights = original_model.get_weights()
    854     first_model = strategy.unwrap(distributed_model)[0]
    855     set_weights(strategy, first_model, orig_model_weights)
    856 
    857 
    858 def _copy_weights_to_original_model(model, mode):
    859   """Copies weights from first distributed model back to original model."""
    860   if model._distribution_strategy and mode == ModeKeys.TRAIN:
    861     distributed_model = get_distributed_model(model, mode)
    862     updated_weights = model._distribution_strategy.unwrap(
    863         distributed_model)[0].get_weights()
    864     model.set_weights(updated_weights)
    865 
    866 
    867 def _per_device_aggregate_batch(batch_outs, model, mode):
    868   """Aggregates the per-device batch-level outputs from a distributed step."""
    869   if model._distribution_strategy is not None and mode == ModeKeys.PREDICT:
    870     total_batch_outs = []
    871     for i in range(len(model.outputs)):
    872       num_replicas = model._distribution_strategy.num_replicas_in_sync
    873       nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas]
    874       total_batch_outs.append(np.concatenate(nest.flatten(nested_outs)))
    875     return total_batch_outs
    876   return batch_outs
    877 
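# Worked example: in PREDICT mode with 2 replicas and a single model output,
# `batch_outs` arrives as [out_replica0, out_replica1], each of shape
# (per_replica_batch, ...), and is concatenated into one array of shape
# (global_batch, ...), matching what a non-distributed predict step returns.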
    878 
    879 def _reset_metrics(model):
    880   if model._distribution_strategy:
    881     for mode in [ModeKeys.TRAIN, ModeKeys.TEST, ModeKeys.PREDICT]:
    882       distributed_model = get_distributed_model(model, mode)
    883       if distributed_model:
    884         first_model = model._distribution_strategy.unwrap(distributed_model)[0]
    885         first_model.reset_metrics()
    886 
    887 
    888 def get_distributed_model(model, mode):
    889   key = _generate_cache_key(mode)
    890   return model._distributed_model_cache.get(key, None)
    891 
    892 
    893 def set_distributed_model(model, mode, distributed_model):
    894   key = _generate_cache_key(mode)
    895   model._distributed_model_cache[key] = distributed_model
    896 
    897 
    898 def _generate_cache_key(mode):
    899   key = hash(mode)
    900   return key
    901 
    902 
    903 @tf_contextlib.contextmanager
    904 def distributed_scope(strategy, learning_phase):
    905   with strategy.scope(), K.learning_phase_scope(learning_phase):
    906     yield
    907 
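# Example usage (hypothetical `strategy`): enter both the strategy scope and a
# Keras learning-phase scope around per-replica graph construction.
#
#   with distributed_scope(strategy, learning_phase=1):
#     ...  # build or run replica-specific ops here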
    908 
    909 def filter_distributed_callbacks(callbacks_list):
    910   """Filter Callbacks based on the worker context when running multi-worker.
    911 
    912   Arguments:
    913     callbacks_list: A list of `Callback` instances.
    914 
    915   Returns:
    916     The list of `Callback` instances that should be run on this worker.
    917   """
    918 
    919   if not K.in_multi_worker_mode():
    920     raise ValueError(
    921         'filter_distributed_callbacks() should only be called when Keras '
    922         'is in multi worker mode.')
    923 
    924   worker_context = dc_context.get_current_worker_context()
    925   callbacks_list = callbacks_list or []
    926   if not [
    927       c for c in callbacks_list if isinstance(c, callbacks.ModelCheckpoint)
    928   ]:
    929     # TODO(rchao): Consider providing a ModelCheckpoint here if the user
    930     # fails to.
    931     logging.warning('ModelCheckpoint callback is not provided. '
    932                     'Workers will need to restart training if any fails.')
    933   # TODO(rchao): Add similar warning for restoring callback (to be designed).
    934 
    935   if callbacks_list is None or worker_context.is_chief:
    936     return callbacks_list
    937 
    938   # Some Callbacks should only run on the chief worker.
    939   return [
    940       callback for callback in callbacks_list if not callback._chief_worker_only
    941   ]  # pylint: disable=protected-access
    942