      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Training-related part of the Keras engine.
     16 """
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import collections
     22 import numpy as np
     23 
     24 from tensorflow.python import tf2
     25 from tensorflow.python.data.ops import dataset_ops
     26 from tensorflow.python.data.ops import iterator_ops
     27 from tensorflow.python.distribute import distribute_coordinator as dc
     28 from tensorflow.python.distribute import distribution_strategy_context
     29 from tensorflow.python.eager import context
     30 from tensorflow.python.framework import ops
     31 from tensorflow.python.framework import tensor_shape
     32 from tensorflow.python.framework import tensor_util
     33 from tensorflow.python.keras import backend as K
     34 from tensorflow.python.keras import losses
     35 from tensorflow.python.keras import metrics as metrics_module
     36 from tensorflow.python.keras import optimizers
     37 from tensorflow.python.keras.engine import distributed_training_utils
     38 from tensorflow.python.keras.engine import network
     39 from tensorflow.python.keras.engine import training_arrays
     40 from tensorflow.python.keras.engine import training_distributed
     41 from tensorflow.python.keras.engine import training_eager
     42 from tensorflow.python.keras.engine import training_generator
     43 from tensorflow.python.keras.engine import training_utils
     44 from tensorflow.python.keras.saving import saving_utils
     45 from tensorflow.python.keras.utils import data_utils
     46 from tensorflow.python.keras.utils import losses_utils
     47 from tensorflow.python.keras.utils.generic_utils import slice_arrays
     48 from tensorflow.python.keras.utils.mode_keys import ModeKeys
     49 from tensorflow.python.ops import math_ops
     50 from tensorflow.python.platform import tf_logging as logging
     51 from tensorflow.python.training.tracking import base as trackable
     52 from tensorflow.python.util import nest
     53 from tensorflow.python.util.tf_export import keras_export
     54 
     55 
     56 @keras_export('keras.models.Model', 'keras.Model')
     57 class Model(network.Network):
     58   """`Model` groups layers into an object with training and inference features.
     59 
     60   There are two ways to instantiate a `Model`:
     61 
     62   1 - With the "functional API", where you start from `Input`,
     63   you chain layer calls to specify the model's forward pass,
     64   and finally you create your model from inputs and outputs:
     65 
     66   ```python
     67   import tensorflow as tf
     68 
     69   inputs = tf.keras.Input(shape=(3,))
     70   x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
     71   outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
     72   model = tf.keras.Model(inputs=inputs, outputs=outputs)
     73   ```
     74 
     75   2 - By subclassing the `Model` class: in that case, you should define your
     76   layers in `__init__` and you should implement the model's forward pass
     77   in `call`.
     78 
     79   ```python
     80   import tensorflow as tf
     81 
     82   class MyModel(tf.keras.Model):
     83 
     84     def __init__(self):
     85       super(MyModel, self).__init__()
     86       self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
     87       self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
     88 
     89     def call(self, inputs):
     90       x = self.dense1(inputs)
     91       return self.dense2(x)
     92 
     93   model = MyModel()
     94   ```
     95 
     96   If you subclass `Model`, you can optionally have
     97   a `training` argument (boolean) in `call`, which you can use to specify
     98   a different behavior in training and inference:
     99 
    100   ```python
    101   import tensorflow as tf
    102 
    103   class MyModel(tf.keras.Model):
    104 
    105     def __init__(self):
    106       super(MyModel, self).__init__()
    107       self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    108       self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
    109       self.dropout = tf.keras.layers.Dropout(0.5)
    110 
    111     def call(self, inputs, training=False):
    112       x = self.dense1(inputs)
    113       if training:
    114         x = self.dropout(x, training=training)
    115       return self.dense2(x)
    116 
    117   model = MyModel()
    118   ```
    119   """
    120 
    121   def __init__(self, *args, **kwargs):
    122     super(Model, self).__init__(*args, **kwargs)
    123     # initializing _distribution_strategy here since it is possible to call
    124     # predict on a model without compiling it.
    125     self._distribution_strategy = None
    126     # This flag is used to track if the user is using the deprecated path of
    127     # passing distribution strategy to compile rather than creating the model
    128     # under distribution strategy scope.
    129     self._compile_distribution = False
    130 
    131     self.run_eagerly = None
    132 
    133   def get_weights(self):
    134     """Retrieves the weights of the model.
    135 
    136     Returns:
    137         A flat list of Numpy arrays.
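
             For example, the returned list can be fed back through `set_weights`
             (a minimal sketch; `model` stands for any built `tf.keras.Model`):

             ```python
             weights = model.get_weights()  # flat list of Numpy arrays
             model.set_weights(weights)     # restores the same values
             ```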
    138     """
    139     if self._distribution_strategy:
    140       with self._distribution_strategy.scope():
    141         return super(Model, self).get_weights()
    142     return super(Model, self).get_weights()
    143 
    144   def load_weights(self, filepath, by_name=False):
    145     """Loads all layer weights, either from a TensorFlow or an HDF5 file."""
    146     if distributed_training_utils.is_tpu_strategy(self._distribution_strategy):
    147       if (self._distribution_strategy.extended.steps_per_run > 1 and
    148           (not network._is_hdf5_filepath(filepath))):  # pylint: disable=protected-access
    149         raise ValueError('Load weights is not yet supported with TPUStrategy '
    150                          'with steps_per_run greater than 1.')
    151     return super(Model, self).load_weights(filepath, by_name)
    152 
    153   @trackable.no_automatic_dependency_tracking
    154   def compile(self,
    155               optimizer,
    156               loss=None,
    157               metrics=None,
    158               loss_weights=None,
    159               sample_weight_mode=None,
    160               weighted_metrics=None,
    161               target_tensors=None,
    162               distribute=None,
    163               **kwargs):
    164     """Configures the model for training.
    165 
    166     Arguments:
    167         optimizer: String (name of optimizer) or optimizer instance.
    168             See `tf.keras.optimizers`.
    169         loss: String (name of objective function), objective function or
    170             `tf.losses.Loss` instance. See `tf.losses`. If the model has
    171             multiple outputs, you can use a different loss on each output by
    172             passing a dictionary or a list of losses. The loss value that will
    173             be minimized by the model will then be the sum of all individual
    174             losses.
    175         metrics: List of metrics to be evaluated by the model during training
    176             and testing. Typically you will use `metrics=['accuracy']`.
    177             To specify different metrics for different outputs of a
    178             multi-output model, you could also pass a dictionary, such as
    179             `metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}`.
    180             You can also pass a list (len = len(outputs)) of lists of metrics
    181             such as `metrics=[['accuracy'], ['accuracy', 'mse']]` or
    182             `metrics=['accuracy', ['accuracy', 'mse']]`.
    183         loss_weights: Optional list or dictionary specifying scalar
    184             coefficients (Python floats) to weight the loss contributions
    185             of different model outputs.
    186             The loss value that will be minimized by the model
    187             will then be the *weighted sum* of all individual losses,
    188             weighted by the `loss_weights` coefficients.
    189             If a list, it is expected to have a 1:1 mapping
     190             to the model's outputs. If a dict, it is expected to map
    191             output names (strings) to scalar coefficients.
    192         sample_weight_mode: If you need to do timestep-wise
    193             sample weighting (2D weights), set this to `"temporal"`.
    194             `None` defaults to sample-wise weights (1D).
    195             If the model has multiple outputs, you can use a different
    196             `sample_weight_mode` on each output by passing a
    197             dictionary or a list of modes.
    198         weighted_metrics: List of metrics to be evaluated and weighted
    199             by sample_weight or class_weight during training and testing.
    200         target_tensors: By default, Keras will create placeholders for the
    201             model's target, which will be fed with the target data during
    202             training. If instead you would like to use your own
    203             target tensors (in turn, Keras will not expect external
    204             Numpy data for these targets at training time), you
    205             can specify them via the `target_tensors` argument. It can be
    206             a single tensor (for a single-output model), a list of tensors,
    207             or a dict mapping output names to target tensors.
    208         distribute: NOT SUPPORTED IN TF 2.0, please create and compile the
    209             model under distribution strategy scope instead of passing it to
    210             compile.
    211         **kwargs: Any additional arguments.
    212 
    213     Raises:
    214         ValueError: In case of invalid arguments for
    215             `optimizer`, `loss`, `metrics` or `sample_weight_mode`.
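
             Example (a minimal sketch; `model` stands for an already-built model, and
             the optimizer, loss and metric choices below are illustrative rather than
             required):

             ```python
             # Single-output case:
             model.compile(optimizer='rmsprop', loss='mse', metrics=['mae'])

             # Multi-output case, with per-output losses and loss weights:
             model.compile(
                 optimizer='rmsprop',
                 loss={'output_a': 'binary_crossentropy', 'output_b': 'mse'},
                 loss_weights={'output_a': 1.0, 'output_b': 0.2})
             ```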
    216     """
    217     run_eagerly = kwargs.pop('run_eagerly', None)
    218     if run_eagerly and getattr(self, '_contains_symbolic_tensors', False):
    219       raise ValueError(
    220           'We currently do not support enabling `run_eagerly` on compile if '
    221           '`model.add_loss(tensor)` or `model.add_metric(tensor)` '
    222           'has been called.')
    223 
    224     self._run_eagerly = run_eagerly
    225     optimizer = optimizers.get(optimizer)
    226 
    227     if distribute is not None:
    228       if tf2.enabled():
    229         raise ValueError(
     230             'Distribute argument in compile is not available in TF 2.0. Please '
    231             'create the model under the distribution strategy scope.')
     232       logging.warning('Distribute argument in compile is deprecated. Please '
    233                       'create the model under the distribution strategy scope.')
    234       self._distribution_strategy = distribute
    235       self._compile_distribution = True
    236     else:
    237       if distribution_strategy_context.has_strategy():
    238         # When the user builds the model in the DS scope and cross replica
    239         # context we want distribution strategy to be set but when building the
    240         # replica copies of the models internally we should not be compiling
    241         # with distribution strategy and use the default compilation path.
    242         if distribution_strategy_context.in_cross_replica_context():
    243           self._distribution_strategy = (
    244               distribution_strategy_context.get_strategy())
    245 
    246     # Validate that arguments passed by the user to `compile` are supported by
    247     # DistributionStrategy.
    248     if self._distribution_strategy:
    249       if sample_weight_mode:
    250         raise NotImplementedError('sample_weight_mode is not supported with '
    251                                   'DistributionStrategy.')
    252       if weighted_metrics:
    253         raise NotImplementedError('weighted_metrics is not supported with '
    254                                   'DistributionStrategy.')
    255       if target_tensors:
    256         raise ValueError('target_tensors is not supported with '
    257                          'DistributionStrategy.')
    258 
    259       if run_eagerly:
    260         raise ValueError(
    261             'We currently do not support enabling `run_eagerly` with '
    262             'distribution strategy.')
    263 
    264       if getattr(self, '_contains_symbolic_tensors', False):
    265         raise ValueError(
    266             'We currently do not support compiling the model with distribution '
    267             'strategy if `model.add_loss(tensor)` or `model.add_metric(tensor)`'
    268             ' has been called.')
    269 
    270       if not self.built or not self.inputs or not self.outputs:
    271         raise ValueError(
    272             'We currently do not support distribution strategy with a '
    273             '`Sequential` model that is created without `input_shape`/'
    274             '`input_dim` set in its first layer or a subclassed model.')
    275 
    276     loss = loss or {}
    277 
    278     self.optimizer = optimizer
    279     # We've disabled automatic dependency tracking for this method, but do want
    280     # to add a checkpoint dependency on the optimizer if it's trackable.
    281     if isinstance(self.optimizer, trackable.Trackable):
    282       self._track_trackable(
    283           self.optimizer, name='optimizer', overwrite=True)
    284     self.loss = loss
    285     self._compile_metrics = metrics or []
    286     self.loss_weights = loss_weights
    287     self.sample_weight_mode = sample_weight_mode
    288     self._compile_weighted_metrics = weighted_metrics
    289     if self.run_eagerly and target_tensors is not None:
    290       raise ValueError(
    291           'target_tensors argument is not supported when '
    292           'running a model eagerly.')
    293     self.target_tensors = target_tensors
    294 
    295     # Set DistributionStrategy specific parameters.
    296     self._distributed_model_cache = {}
    297 
    298     if self._distribution_strategy is not None:
    299       # Ensures a Session is created and configured correctly for Distribution
    300       # Strategy.
    301       K.configure_and_create_distributed_session(self._distribution_strategy)
    302     # Initialize model metric attributes.
    303     self._init_metric_attributes()
    304     if not self.built or not self.inputs or not self.outputs:
    305       # Model is not compilable because it does not know its number of inputs
    306       # and outputs, nor their shapes and names. We will compile after the first
    307       # time the model gets called on training data.
    308       return
    309     self._is_compiled = True
    310 
     311     # Prepare list of loss functions, same size as model outputs.
    312     self.loss_functions = training_utils.prepare_loss_functions(
    313         loss, self.output_names)
    314 
    315     self._feed_outputs = []
    316     self._feed_output_names = []
    317     self._feed_output_shapes = []
    318     self._feed_loss_fns = []
    319     # if loss function is None, then this output will be skipped during total
    320     # loss calculation and feed targets preparation.
    321     skip_target_indices = []
    322     skip_target_weighing_indices = []
    323     for i, loss_function in enumerate(self.loss_functions):
    324       if loss_function is None:
    325         skip_target_indices.append(i)
    326         skip_target_weighing_indices.append(i)
    327 
    328     # Prepare output masks.
    329     if not self.run_eagerly:
    330       masks = [getattr(x, '_keras_mask', None) for x in self.outputs]
    331 
     332     # Prepare list of loss weights, same size as model outputs.
    333     self.loss_weights_list = training_utils.prepare_loss_weights(
    334         self.output_names, loss_weights)
    335 
    336     # Initialization for Eager mode execution.
    337     if self.run_eagerly:
    338       # Prepare sample weights.
    339       self._set_sample_weight_attributes(sample_weight_mode,
    340                                          skip_target_weighing_indices)
    341       # Save all metric attributes per output of the model.
    342       self._cache_output_metric_attributes(metrics, weighted_metrics)
    343 
    344       if target_tensors is not None:
    345         raise ValueError('target_tensors are not currently supported in Eager '
    346                          'mode.')
    347       self.total_loss = None
    348 
    349       # Set metric attributes on model.
    350       self._set_metric_attributes(skip_target_indices=skip_target_indices)
    351 
    352       self.targets = []
    353       for i in range(len(self.outputs)):
    354         self._feed_output_names.append(self.output_names[i])
    355       self._collected_trainable_weights = self.trainable_weights
    356       return
    357 
    358     with K.get_graph().as_default():
    359       # Prepare targets of model.
    360       self.targets = []
    361       self._feed_targets = []
    362       if target_tensors not in (None, []):
    363         if isinstance(target_tensors, list):
    364           if len(target_tensors) != len(self.outputs):
    365             raise ValueError(
    366                 'When passing a list as `target_tensors`, '
    367                 'it should have one entry per model output. '
    368                 'The model has %s outputs, but you passed target_tensors=%s' %
    369                 (len(self.outputs), target_tensors))
    370         elif isinstance(target_tensors, dict):
    371           for name in target_tensors:
    372             if name not in self.output_names:
    373               raise ValueError(
    374                   'Unknown entry in `target_tensors` '
    375                   'dictionary: "' + name + '". '
    376                   'Only expected the following keys: ' + str(self.output_names))
    377           tmp_target_tensors = []
    378           for name in self.output_names:
    379             tmp_target_tensors.append(target_tensors.get(name, None))
    380           target_tensors = tmp_target_tensors
    381         elif tensor_util.is_tensor(target_tensors):
    382           target_tensors = [target_tensors]
    383         else:
    384           raise TypeError('Expected `target_tensors` to be a list or tuple or '
    385                           'dict or a single tensor, but got:', target_tensors)
    386 
    387       for i in range(len(self.outputs)):
    388         if i in skip_target_indices:
    389           self.targets.append(None)
    390         else:
    391           shape = K.int_shape(self.outputs[i])
    392           name = self.output_names[i]
    393           if target_tensors not in (None, []):
    394             target = target_tensors[i]
    395           else:
    396             target = None
    397           if target is None or K.is_placeholder(target):
    398             if target is None:
    399               target_dtype = losses.LABEL_DTYPES_FOR_LOSSES.get(
    400                   self.loss_functions[i],
    401                   K.dtype(self.outputs[i]))
    402 
    403               target = K.placeholder(
    404                   ndim=len(shape),
    405                   name=name + '_target',
    406                   sparse=K.is_sparse(self.outputs[i]),
    407                   dtype=target_dtype)
    408             self._feed_targets.append(target)
    409             self._feed_outputs.append(self.outputs[i])
    410             self._feed_output_names.append(name)
    411             self._feed_output_shapes.append(shape)
    412             self._feed_loss_fns.append(self.loss_functions[i])
    413           else:
    414             skip_target_weighing_indices.append(i)
    415           self.targets.append(target)
    416 
    417       # Prepare sample weights.
    418       self._set_sample_weight_attributes(sample_weight_mode,
    419                                          skip_target_weighing_indices)
    420       # Save all metric attributes per output of the model.
    421       self._cache_output_metric_attributes(metrics, weighted_metrics)
    422 
    423       # Set metric attributes on model.
    424       self._set_metric_attributes(skip_target_indices=skip_target_indices)
    425 
    426       # Invoke metric functions for all the outputs.
    427       self._handle_metrics(
    428           self.outputs,
    429           masks=masks,
    430           targets=self.targets,
    431           skip_target_indices=skip_target_indices,
    432           sample_weights=self.sample_weights)
    433 
    434       # Compute total loss.
    435       # Used to keep track of the total loss value (stateless).
     436       # e.g., total_loss = loss_weight_1 * output_1_loss_fn(...) +
    437       #                   loss_weight_2 * output_2_loss_fn(...) +
    438       #                   layer losses.
    439       self.total_loss = self._prepare_total_loss(skip_target_indices, masks)
    440 
    441       # Functions for train, test and predict will
    442       # be compiled lazily when required.
    443       # This saves time when the user is not using all functions.
    444       self._function_kwargs = kwargs
    445 
    446       self.train_function = None
    447       self.test_function = None
    448       self.predict_function = None
    449 
    450       # Collected trainable weights, sorted in topological order.
    451       trainable_weights = self.trainable_weights
    452       self._collected_trainable_weights = trainable_weights
    453 
    454       # Validate all variables were correctly created in distribution scope.
    455       if self._distribution_strategy and not self._compile_distribution:
    456         for v in self.variables:
    457           strategy = self._distribution_strategy
    458           if not strategy.extended.variable_created_in_scope(v):
    459             raise ValueError(
    460                 'Variable (%s) was not created in the distribution strategy '
     461                 'scope of (%s). It is most likely because some layers, the model, '
     462                 'or the optimizer were created outside the distribution '
    463                 'strategy scope. Try to make sure your code looks similar '
    464                 'to the following.\n'
    465                 'with strategy.scope():\n'
    466                 '  model=_create_model()\n'
    467                 '  model.compile(...)'% (v, strategy))
    468 
    469   @property
    470   def metrics(self):
     471     """Returns the model's metrics added using the `compile` and `add_metric` APIs."""
    472     metrics = []
    473     if self._is_compiled:
    474       metrics += self._compile_metric_functions
    475     return metrics + super(Model, self).metrics
    476 
    477   @property
    478   def metrics_names(self):
    479     """Returns the model's display labels for all outputs."""
    480     metrics_names = []
    481     if self._is_compiled:
    482       metrics_names += self._compile_metrics_names  # Includes names of losses.
    483 
    484     # Add metric names from layers.
    485     for layer in self.layers:
    486       metrics_names += [m.name for m in layer._metrics]  # pylint: disable=protected-access
    487     metrics_names += [m.name for m in self._metrics]
    488     return metrics_names
    489 
    490   @property
    491   def run_eagerly(self):
    492     """Settable attribute indicating whether the model should run eagerly.
    493 
    494     Running eagerly means that your model will be run step by step,
    495     like Python code. Your model might run slower, but it should become easier
    496     for you to debug it by stepping into individual layer calls.
    497 
    498     By default, we will attempt to compile your model to a static graph to
    499     deliver the best execution performance.
    500 
    501     Returns:
    502       Boolean, whether the model should run eagerly.
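
             For example (a minimal sketch; assumes eager execution is enabled, as it
             is by default in TF 2.x):

             ```python
             model.compile(optimizer='sgd', loss='mse')
             model.run_eagerly = True  # later fit/evaluate/predict calls run step by step
             ```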
    503     """
    504     if self._run_eagerly is True and not context.executing_eagerly():
    505       raise ValueError('You can only set `run_eagerly=True` if eager execution '
    506                        'is enabled.')
    507     if not self.dynamic:
    508       if self._run_eagerly is None:
    509         return False
    510       else:
    511         return self._run_eagerly
    512     else:
    513       if not context.executing_eagerly():
    514         raise ValueError('Your model contains layers that can only be '
    515                          'successfully run in eager execution (layers '
    516                          'constructed with `dynamic=True`). '
    517                          'You must enable eager execution with '
    518                          '`tf.enable_eager_execution()`.')
    519       if self._run_eagerly is False:
    520         # TODO(fchollet): consider using py_func to enable this.
    521         raise ValueError('Your model contains layers that can only be '
    522                          'successfully run in eager execution (layers '
    523                          'constructed with `dynamic=True`). '
    524                          'You cannot set `run_eagerly=False`.')
    525       return context.executing_eagerly()
    526 
    527   @run_eagerly.setter
    528   def run_eagerly(self, value):
    529     self._run_eagerly = value
    530 
    531   def fit(self,
    532           x=None,
    533           y=None,
    534           batch_size=None,
    535           epochs=1,
    536           verbose=1,
    537           callbacks=None,
    538           validation_split=0.,
    539           validation_data=None,
    540           shuffle=True,
    541           class_weight=None,
    542           sample_weight=None,
    543           initial_epoch=0,
    544           steps_per_epoch=None,
    545           validation_steps=None,
    546           validation_freq=1,
    547           max_queue_size=10,
    548           workers=1,
    549           use_multiprocessing=False,
    550           **kwargs):
    551     """Trains the model for a fixed number of epochs (iterations on a dataset).
    552 
    553     Arguments:
    554         x: Input data. It could be:
    555           - A Numpy array (or array-like), or a list of arrays
    556             (in case the model has multiple inputs).
    557           - A TensorFlow tensor, or a list of tensors
    558             (in case the model has multiple inputs).
    559           - A dict mapping input names to the corresponding array/tensors,
    560             if the model has named inputs.
    561           - A `tf.data` dataset or a dataset iterator. Should return a tuple
    562             of either `(inputs, targets)` or
    563             `(inputs, targets, sample_weights)`.
    564           - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
    565             or `(inputs, targets, sample weights)`.
    566         y: Target data. Like the input data `x`,
    567           it could be either Numpy array(s) or TensorFlow tensor(s).
    568           It should be consistent with `x` (you cannot have Numpy inputs and
    569           tensor targets, or inversely). If `x` is a dataset, dataset
    570           iterator, generator, or `keras.utils.Sequence` instance, `y` should
    571           not be specified (since targets will be obtained from `x`).
    572         batch_size: Integer or `None`.
    573             Number of samples per gradient update.
    574             If unspecified, `batch_size` will default to 32.
    575             Do not specify the `batch_size` if your data is in the
    576             form of symbolic tensors, dataset, dataset iterators,
    577             generators, or `keras.utils.Sequence` instances (since they generate
    578             batches).
    579         epochs: Integer. Number of epochs to train the model.
    580             An epoch is an iteration over the entire `x` and `y`
    581             data provided.
    582             Note that in conjunction with `initial_epoch`,
    583             `epochs` is to be understood as "final epoch".
    584             The model is not trained for a number of iterations
    585             given by `epochs`, but merely until the epoch
    586             of index `epochs` is reached.
    587         verbose: Integer. 0, 1, or 2. Verbosity mode.
    588             0 = silent, 1 = progress bar, 2 = one line per epoch.
    589         callbacks: List of `keras.callbacks.Callback` instances.
    590             List of callbacks to apply during training.
    591             See `tf.keras.callbacks`.
    592         validation_split: Float between 0 and 1.
    593             Fraction of the training data to be used as validation data.
    594             The model will set apart this fraction of the training data,
    595             will not train on it, and will evaluate
    596             the loss and any model metrics
    597             on this data at the end of each epoch.
    598             The validation data is selected from the last samples
    599             in the `x` and `y` data provided, before shuffling. This argument is
    600             not supported when `x` is a dataset, dataset iterator, generator or
     601             `keras.utils.Sequence` instance.
    602         validation_data: Data on which to evaluate
    603             the loss and any model metrics at the end of each epoch.
    604             The model will not be trained on this data.
    605             `validation_data` will override `validation_split`.
    606             `validation_data` could be:
    607               - tuple `(x_val, y_val)` of Numpy arrays or tensors
    608               - tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays
    609               - dataset or a dataset iterator
    610             For the first two cases, `batch_size` must be provided.
    611             For the last case, `validation_steps` must be provided.
    612         shuffle: Boolean (whether to shuffle the training data
    613             before each epoch) or str (for 'batch').
    614             'batch' is a special option for dealing with the
    615             limitations of HDF5 data; it shuffles in batch-sized chunks.
    616             Has no effect when `steps_per_epoch` is not `None`.
    617         class_weight: Optional dictionary mapping class indices (integers)
    618             to a weight (float) value, used for weighting the loss function
    619             (during training only).
    620             This can be useful to tell the model to
    621             "pay more attention" to samples from
    622             an under-represented class.
    623         sample_weight: Optional Numpy array of weights for
    624             the training samples, used for weighting the loss function
    625             (during training only). You can either pass a flat (1D)
    626             Numpy array with the same length as the input samples
    627             (1:1 mapping between weights and samples),
    628             or in the case of temporal data,
    629             you can pass a 2D array with shape
    630             `(samples, sequence_length)`,
    631             to apply a different weight to every timestep of every sample.
    632             In this case you should make sure to specify
    633             `sample_weight_mode="temporal"` in `compile()`. This argument is not
    634             supported when `x` is a dataset, dataset iterator, generator, or
     635             `keras.utils.Sequence` instance; instead, provide the sample weights
    636             as the third element of `x`.
    637         initial_epoch: Integer.
    638             Epoch at which to start training
    639             (useful for resuming a previous training run).
    640         steps_per_epoch: Integer or `None`.
    641             Total number of steps (batches of samples)
    642             before declaring one epoch finished and starting the
    643             next epoch. When training with input tensors such as
    644             TensorFlow data tensors, the default `None` is equal to
    645             the number of samples in your dataset divided by
    646             the batch size, or 1 if that cannot be determined. If x is a
    647             `tf.data` dataset or a dataset iterator, and 'steps_per_epoch'
    648             is None, the epoch will run until the input dataset is exhausted.
    649         validation_steps: Only relevant if `validation_data` is provided and
    650             is a dataset or dataset iterator. Total number of steps (batches of
    651             samples) to draw before stopping when performing validation
    652             at the end of every epoch. If validation_data is a `tf.data` dataset
    653             or a dataset iterator, and 'validation_steps' is None, validation
    654             will run until the `validation_data` dataset is exhausted.
    655         validation_freq: Only relevant if validation data is provided. Integer
    656             or `collections.Container` instance (e.g. list, tuple, etc.). If an
    657             integer, specifies how many training epochs to run before a new
    658             validation run is performed, e.g. `validation_freq=2` runs
    659             validation every 2 epochs. If a Container, specifies the epochs on
    660             which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
    661             validation at the end of the 1st, 2nd, and 10th epochs.
    662         max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
    663             input only. Maximum size for the generator queue.
    664             If unspecified, `max_queue_size` will default to 10.
    665         workers: Integer. Used for generator or `keras.utils.Sequence` input
    666             only. Maximum number of processes to spin up
    667             when using process-based threading. If unspecified, `workers`
    668             will default to 1. If 0, will execute the generator on the main
    669             thread.
    670         use_multiprocessing: Boolean. Used for generator or
    671             `keras.utils.Sequence` input only. If `True`, use process-based
    672             threading. If unspecified, `use_multiprocessing` will default to
    673             `False`. Note that because this implementation relies on
    674             multiprocessing, you should not pass non-picklable arguments to
    675             the generator as they can't be passed easily to children processes.
    676         **kwargs: Used for backwards compatibility.
    677 
    678     Returns:
    679         A `History` object. Its `History.history` attribute is
    680         a record of training loss values and metrics values
    681         at successive epochs, as well as validation loss values
    682         and validation metrics values (if applicable).
    683 
    684     Raises:
    685         RuntimeError: If the model was never compiled.
    686         ValueError: In case of mismatch between the provided input data
    687             and what the model expects.
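
             Example (a minimal sketch with random Numpy data matching the 3-feature,
             5-output functional model from the class docstring above; the shapes,
             epoch count and split are illustrative):

             ```python
             import numpy as np

             data = np.random.random((1000, 3))
             labels = np.random.random((1000, 5))
             # Assumes `model` has already been compiled, e.g. with loss='mse'.
             history = model.fit(
                 data, labels, epochs=10, batch_size=32, validation_split=0.2)
             print(history.history['loss'])
             ```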
    688     """
    689     # Legacy support
    690     if 'nb_epoch' in kwargs:
    691       logging.warning(
    692           'The `nb_epoch` argument in `fit` '
    693           'has been renamed `epochs`.')
    694       epochs = kwargs.pop('nb_epoch')
    695     if kwargs:
    696       raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
    697 
    698     # Case 1: distribution strategy.
    699     if self._distribution_strategy:
    700       if K.in_multi_worker_mode():
    701         # Multi-Worker mode runs the Keras training loop on multiple
    702         # servers via the Distribute Coordinator.
    703         def _worker_fn(_):
    704           """Run training inside the distributed coordinator."""
    705           filtered_callbacks = distributed_training_utils \
    706               .filter_distributed_callbacks(callbacks)
    707           return training_distributed.fit_distributed(
    708               self,
    709               x=x,
    710               y=y,
    711               batch_size=batch_size,
    712               epochs=epochs,
    713               verbose=verbose,
    714               callbacks=filtered_callbacks,
    715               validation_split=validation_split,
    716               validation_data=validation_data,
    717               shuffle=shuffle,
    718               class_weight=class_weight,
    719               sample_weight=sample_weight,
    720               initial_epoch=initial_epoch,
    721               steps_per_epoch=steps_per_epoch,
    722               validation_steps=validation_steps,
    723               validation_freq=validation_freq)
    724 
    725         # Independent worker only for now.
    726         return dc.run_distribute_coordinator(
    727             _worker_fn,
    728             self._distribution_strategy,
    729             mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
    730       else:
    731         return training_distributed.fit_distributed(
    732             self,
    733             x=x,
    734             y=y,
    735             batch_size=batch_size,
    736             epochs=epochs,
    737             verbose=verbose,
    738             callbacks=callbacks,
    739             validation_split=validation_split,
    740             validation_data=validation_data,
    741             shuffle=shuffle,
    742             class_weight=class_weight,
    743             sample_weight=sample_weight,
    744             initial_epoch=initial_epoch,
    745             steps_per_epoch=steps_per_epoch,
    746             validation_steps=validation_steps,
    747             validation_freq=validation_freq)
    748 
    749     batch_size = self._validate_or_infer_batch_size(
    750         batch_size, steps_per_epoch, x)
    751 
    752     # Case 2: generator-like. Input is Python generator, or Sequence object,
    753     # or a non-distributed Dataset or iterator in eager execution.
    754     if data_utils.is_generator_or_sequence(x):
    755       training_utils.check_generator_arguments(
    756           y, sample_weight, validation_split=validation_split)
    757       return self.fit_generator(
    758           x,
    759           steps_per_epoch=steps_per_epoch,
    760           epochs=epochs,
    761           verbose=verbose,
    762           callbacks=callbacks,
    763           validation_data=validation_data,
    764           validation_steps=validation_steps,
    765           validation_freq=validation_freq,
    766           class_weight=class_weight,
    767           max_queue_size=max_queue_size,
    768           workers=workers,
    769           use_multiprocessing=use_multiprocessing,
    770           shuffle=shuffle,
    771           initial_epoch=initial_epoch)
    772     if training_utils.is_eager_dataset_or_iterator(x):
    773       # Make sure that y, sample_weights, validation_split are not passed.
    774       training_utils.validate_dataset_input(x, y, sample_weight,
    775                                             validation_split)
    776       if (isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2))
    777           and shuffle):
    778         training_utils.verify_dataset_shuffled(x)
    779 
    780       return self.fit_generator(
    781           x,
    782           steps_per_epoch=steps_per_epoch,
    783           epochs=epochs,
    784           verbose=verbose,
    785           callbacks=callbacks,
    786           validation_data=validation_data,
    787           validation_steps=validation_steps,
    788           validation_freq=validation_freq,
    789           class_weight=class_weight,
    790           workers=0,
    791           shuffle=shuffle,
    792           initial_epoch=initial_epoch)
    793 
    794     # Case 3: Symbolic tensors or Numpy array-like.
    795     # This includes Datasets and iterators in graph mode (since they
    796     # generate symbolic tensors).
    797     x, y, sample_weights = self._standardize_user_data(
    798         x,
    799         y,
    800         sample_weight=sample_weight,
    801         class_weight=class_weight,
    802         batch_size=batch_size,
    803         check_steps=True,
    804         steps_name='steps_per_epoch',
    805         steps=steps_per_epoch,
    806         validation_split=validation_split,
    807         shuffle=shuffle)
    808 
    809     # Prepare validation data.
    810     if validation_data:
    811       val_x, val_y, val_sample_weights = self._unpack_validation_data(
    812           validation_data)
    813       val_x, val_y, val_sample_weights = self._standardize_user_data(
    814           val_x,
    815           val_y,
    816           sample_weight=val_sample_weights,
    817           batch_size=batch_size,
    818           steps=validation_steps,
    819           steps_name='validation_steps')
    820     elif validation_split and 0. < validation_split < 1.:
    821       if training_utils.has_symbolic_tensors(x):
    822         raise ValueError('If your data is in the form of symbolic tensors, '
    823                          'you cannot use `validation_split`.')
    824       if hasattr(x[0], 'shape'):
    825         split_at = int(x[0].shape[0] * (1. - validation_split))
    826       else:
    827         split_at = int(len(x[0]) * (1. - validation_split))
    828       x, val_x = (slice_arrays(x, 0, split_at), slice_arrays(x, split_at))
    829       y, val_y = (slice_arrays(y, 0, split_at), slice_arrays(y, split_at))
    830       sample_weights, val_sample_weights = (slice_arrays(
    831           sample_weights, 0, split_at), slice_arrays(sample_weights, split_at))
    832     elif validation_steps:
    833       val_x = []
    834       val_y = []
    835       val_sample_weights = []
    836     else:
    837       val_x = None
    838       val_y = None
    839       val_sample_weights = None
    840 
    841     if self.run_eagerly:
    842       return training_generator.fit_generator(
    843           self, (x, y, sample_weights),
    844           steps_per_epoch=steps_per_epoch,
    845           batch_size=batch_size,
    846           epochs=epochs,
    847           verbose=verbose,
    848           callbacks=callbacks,
    849           validation_data=validation_data,
    850           validation_steps=validation_steps,
    851           validation_freq=validation_freq,
    852           workers=0,
    853           shuffle=shuffle,
    854           initial_epoch=initial_epoch,
    855           steps_name='steps_per_epoch')
    856     else:
    857       return training_arrays.fit_loop(
    858           self,
    859           x,
    860           y,
    861           sample_weights=sample_weights,
    862           batch_size=batch_size,
    863           epochs=epochs,
    864           verbose=verbose,
    865           callbacks=callbacks,
    866           val_inputs=val_x,
    867           val_targets=val_y,
    868           val_sample_weights=val_sample_weights,
    869           shuffle=shuffle,
    870           initial_epoch=initial_epoch,
    871           steps_per_epoch=steps_per_epoch,
    872           validation_steps=validation_steps,
    873           validation_freq=validation_freq,
    874           steps_name='steps_per_epoch')
    875 
    876   def evaluate(self,
    877                x=None,
    878                y=None,
    879                batch_size=None,
    880                verbose=1,
    881                sample_weight=None,
    882                steps=None,
    883                callbacks=None,
    884                max_queue_size=10,
    885                workers=1,
    886                use_multiprocessing=False):
    887     """Returns the loss value & metrics values for the model in test mode.
    888 
    889     Computation is done in batches.
    890 
    891     Arguments:
    892         x: Input data. It could be:
    893           - A Numpy array (or array-like), or a list of arrays
    894             (in case the model has multiple inputs).
    895           - A TensorFlow tensor, or a list of tensors
    896             (in case the model has multiple inputs).
    897           - A dict mapping input names to the corresponding array/tensors,
    898             if the model has named inputs.
    899           - A `tf.data` dataset or a dataset iterator.
    900           - A generator or `keras.utils.Sequence` instance.
    901         y: Target data. Like the input data `x`,
    902           it could be either Numpy array(s) or TensorFlow tensor(s).
    903           It should be consistent with `x` (you cannot have Numpy inputs and
    904           tensor targets, or inversely).
    905           If `x` is a dataset, dataset iterator, generator or
    906           `keras.utils.Sequence` instance, `y` should not be specified (since
    907           targets will be obtained from the iterator/dataset).
    908         batch_size: Integer or `None`.
     909             Number of samples per batch of computation.
    910             If unspecified, `batch_size` will default to 32.
     911             Do not specify the `batch_size` if your data is in the
    912             form of symbolic tensors, dataset, dataset iterators,
    913             generators, or `keras.utils.Sequence` instances (since they generate
    914             batches).
    915         verbose: 0 or 1. Verbosity mode.
    916             0 = silent, 1 = progress bar.
    917         sample_weight: Optional Numpy array of weights for
    918             the test samples, used for weighting the loss function.
    919             You can either pass a flat (1D)
    920             Numpy array with the same length as the input samples
    921             (1:1 mapping between weights and samples),
    922             or in the case of temporal data,
    923             you can pass a 2D array with shape
    924             `(samples, sequence_length)`,
    925             to apply a different weight to every timestep of every sample.
    926             In this case you should make sure to specify
    927             `sample_weight_mode="temporal"` in `compile()`. This argument is not
     928             supported when `x` is a dataset or a dataset iterator; instead, pass
    929             sample weights as the third element of `x`.
    930         steps: Integer or `None`.
    931             Total number of steps (batches of samples)
    932             before declaring the evaluation round finished.
    933             Ignored with the default value of `None`.
    934             If x is a `tf.data` dataset or a dataset iterator, and `steps` is
    935             None, 'evaluate' will run until the dataset is exhausted.
    936         callbacks: List of `keras.callbacks.Callback` instances.
    937             List of callbacks to apply during evaluation.
    938             See [callbacks](/api_docs/python/tf/keras/callbacks).
    939         max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
    940             input only. Maximum size for the generator queue.
    941             If unspecified, `max_queue_size` will default to 10.
    942         workers: Integer. Used for generator or `keras.utils.Sequence` input
    943             only. Maximum number of processes to spin up when using
    944             process-based threading. If unspecified, `workers` will default
    945             to 1. If 0, will execute the generator on the main thread.
    946         use_multiprocessing: Boolean. Used for generator or
    947             `keras.utils.Sequence` input only. If `True`, use process-based
    948             threading. If unspecified, `use_multiprocessing` will default to
    949             `False`. Note that because this implementation relies on
    950             multiprocessing, you should not pass non-picklable arguments to
    951             the generator as they can't be passed easily to children processes.
    952 
    953     Returns:
    954         Scalar test loss (if the model has a single output and no metrics)
    955         or list of scalars (if the model has multiple outputs
    956         and/or metrics). The attribute `model.metrics_names` will give you
    957         the display labels for the scalar outputs.
    958 
    959     Raises:
    960         ValueError: in case of invalid arguments.
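
             Example (a minimal sketch; `x_test` and `y_test` are placeholder Numpy
             arrays shaped like the model's inputs and targets):

             ```python
             results = model.evaluate(x_test, y_test, batch_size=128)
             # `model.metrics_names` gives the display labels for the returned values.
             ```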
    961     """
    962     # Case 1: distribution strategy.
    963     if self._distribution_strategy:
    964       if K.in_multi_worker_mode():
    965         # Multi-Worker mode runs the Keras evaluation loop on multiple
    966         # servers via the Distribute Coordinator.
    967         def _worker_fn(_):
    968           """Run evaluation inside the distributed coordinator."""
    969           filtered_callbacks = distributed_training_utils \
    970               .filter_distributed_callbacks(callbacks)
    971           return training_distributed.evaluate_distributed(
    972               self,
    973               x=x,
    974               y=y,
    975               batch_size=batch_size,
    976               verbose=verbose,
    977               sample_weight=sample_weight,
    978               steps=steps,
    979               callbacks=filtered_callbacks)
    980 
    981         # Independent worker only for now.
    982         return dc.run_distribute_coordinator(
    983             _worker_fn,
    984             self._distribution_strategy,
    985             mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
    986       else:
    987         return training_distributed.evaluate_distributed(
    988             self,
    989             x=x,
    990             y=y,
    991             batch_size=batch_size,
    992             verbose=verbose,
    993             sample_weight=sample_weight,
    994             steps=steps,
    995             callbacks=callbacks)
    996 
    997     batch_size = self._validate_or_infer_batch_size(batch_size, steps, x)
    998 
    999     # Case 2: generator-like. Input is Python generator, or Sequence object,
   1000     # or a non-distributed Dataset or iterator in eager execution.
   1001     if data_utils.is_generator_or_sequence(x):
   1002       training_utils.check_generator_arguments(y, sample_weight)
   1003       return self.evaluate_generator(
   1004           x,
   1005           steps=steps,
   1006           verbose=verbose,
   1007           callbacks=callbacks,
   1008           max_queue_size=max_queue_size,
   1009           workers=workers,
   1010           use_multiprocessing=use_multiprocessing)
   1011     if training_utils.is_eager_dataset_or_iterator(x):
   1012       # Make sure that y, sample_weights are not passed.
   1013       training_utils.validate_dataset_input(x, y, sample_weight)
   1014       return training_generator.evaluate_generator(
   1015           self, x,
   1016           steps=steps,
   1017           batch_size=batch_size,
   1018           verbose=verbose,
   1019           workers=0,
   1020           callbacks=callbacks)
   1021 
   1022     # Case 3: Symbolic tensors or Numpy array-like.
   1023     # This includes Datasets and iterators in graph mode (since they
   1024     # generate symbolic tensors).
   1025     x, y, sample_weights = self._standardize_user_data(
   1026         x,
   1027         y,
   1028         sample_weight=sample_weight,
   1029         batch_size=batch_size,
   1030         check_steps=True,
   1031         steps_name='steps',
   1032         steps=steps)
   1033 
   1034     if self.run_eagerly:
   1035       return training_generator.evaluate_generator(
   1036           self, (x, y, sample_weights),
   1037           steps=steps,
   1038           batch_size=batch_size,
   1039           verbose=verbose,
   1040           workers=0,
   1041           callbacks=callbacks)
   1042     else:
   1043       return training_arrays.test_loop(
   1044           self,
   1045           inputs=x,
   1046           targets=y,
   1047           sample_weights=sample_weights,
   1048           batch_size=batch_size,
   1049           verbose=verbose,
   1050           steps=steps,
   1051           callbacks=callbacks)
   1052 
   1053   def predict(self,
   1054               x,
   1055               batch_size=None,
   1056               verbose=0,
   1057               steps=None,
   1058               callbacks=None,
   1059               max_queue_size=10,
   1060               workers=1,
   1061               use_multiprocessing=False):
   1062     """Generates output predictions for the input samples.
   1063 
   1064     Computation is done in batches.
   1065 
   1066     Arguments:
   1067          x: Input samples. It could be:
   1068           - A Numpy array (or array-like), or a list of arrays
   1069             (in case the model has multiple inputs).
   1070           - A TensorFlow tensor, or a list of tensors
   1071             (in case the model has multiple inputs).
   1072           - A `tf.data` dataset or a dataset iterator.
   1073           - A generator or `keras.utils.Sequence` instance.
   1074         batch_size: Integer or `None`.
    1075             Number of samples per batch of computation.
   1076             If unspecified, `batch_size` will default to 32.
    1077             Do not specify the `batch_size` if your data is in the
   1078             form of symbolic tensors, dataset, dataset iterators,
   1079             generators, or `keras.utils.Sequence` instances (since they generate
   1080             batches).
   1081         verbose: Verbosity mode, 0 or 1.
   1082         steps: Total number of steps (batches of samples)
   1083             before declaring the prediction round finished.
   1084             Ignored with the default value of `None`. If x is a `tf.data`
   1085             dataset or a dataset iterator, and `steps` is None, `predict` will
   1086             run until the input dataset is exhausted.
   1087         callbacks: List of `keras.callbacks.Callback` instances.
   1088             List of callbacks to apply during prediction.
   1089             See [callbacks](/api_docs/python/tf/keras/callbacks).
   1090         max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
   1091             input only. Maximum size for the generator queue.
   1092             If unspecified, `max_queue_size` will default to 10.
   1093         workers: Integer. Used for generator or `keras.utils.Sequence` input
   1094             only. Maximum number of processes to spin up when using
   1095             process-based threading. If unspecified, `workers` will default
   1096             to 1. If 0, will execute the generator on the main thread.
   1097         use_multiprocessing: Boolean. Used for generator or
   1098             `keras.utils.Sequence` input only. If `True`, use process-based
   1099             threading. If unspecified, `use_multiprocessing` will default to
   1100             `False`. Note that because this implementation relies on
   1101             multiprocessing, you should not pass non-picklable arguments to
   1102             the generator as they can't be passed easily to children processes.
    1103 
   1105     Returns:
   1106         Numpy array(s) of predictions.
   1107 
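             Example:

             A minimal usage sketch (not part of the original docs); the tiny
             `Sequential` model and the random Numpy inputs below are
             illustrative placeholders only.

             ```python
             import numpy as np
             import tensorflow as tf

             # Illustrative model: a single input of shape (8,), one output unit.
             model = tf.keras.Sequential([
                 tf.keras.layers.Dense(4, activation='relu', input_shape=(8,)),
                 tf.keras.layers.Dense(1)
             ])

             x = np.random.random((16, 8))
             predictions = model.predict(x, batch_size=4)  # shape (16, 1)
             ```
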
   1108     Raises:
   1109         ValueError: In case of mismatch between the provided
   1110             input data and the model's expectations,
   1111             or in case a stateful model receives a number of samples
   1112             that is not a multiple of the batch size.
   1113     """
   1114     # Case 1: distribution strategy.
   1115     if self._distribution_strategy:
   1116       return training_distributed.predict_distributed(self,
   1117                                                       x=x,
   1118                                                       batch_size=batch_size,
   1119                                                       verbose=verbose,
   1120                                                       steps=steps,
   1121                                                       callbacks=callbacks)
   1122 
   1123     batch_size = self._validate_or_infer_batch_size(batch_size, steps, x)
   1124 
   1125     # Case 2: generator-like. Input is Python generator, or Sequence object,
   1126     # or a non-distributed Dataset or iterator in eager execution.
   1127     if data_utils.is_generator_or_sequence(x):
   1128       return self.predict_generator(
   1129           x,
   1130           steps=steps,
   1131           verbose=verbose,
   1132           callbacks=callbacks,
   1133           max_queue_size=max_queue_size,
   1134           workers=workers,
   1135           use_multiprocessing=use_multiprocessing)
   1136     if training_utils.is_eager_dataset_or_iterator(x):
   1137       return training_generator.predict_generator(
   1138           self,
   1139           x,
   1140           steps=steps,
   1141           batch_size=batch_size,
   1142           verbose=verbose,
   1143           workers=0,
   1144           callbacks=callbacks)
   1145 
   1146     # Case 3: Symbolic tensors or Numpy array-like.
   1147     # This includes Datasets and iterators in graph mode (since they
   1148     # generate symbolic tensors).
   1149     x, _, _ = self._standardize_user_data(
   1150         x, check_steps=True, steps_name='steps', steps=steps)
   1151 
   1152     if self.run_eagerly:
   1153       return training_generator.predict_generator(
   1154           self,
   1155           x,
   1156           steps=steps,
   1157           batch_size=batch_size,
   1158           verbose=verbose,
   1159           workers=0,
   1160           callbacks=callbacks)
   1161     else:
   1162       return training_arrays.predict_loop(
   1163           self,
   1164           x,
   1165           batch_size=batch_size,
   1166           verbose=verbose,
   1167           steps=steps,
   1168           callbacks=callbacks)
   1169 
   1170   def reset_metrics(self):
   1171     """Resets the state of metrics."""
   1172     if hasattr(self, 'metrics'):
   1173       for m in self.metrics:
   1174         m.reset_states()
   1175 
   1176     # Reset the state of loss metric wrappers.
   1177     if getattr(self, '_output_loss_metrics', None) is not None:
   1178       for m in self._output_loss_metrics:
   1179         m.reset_states()
   1180 
   1181     # Reset metrics on all the distributed (cloned) models.
   1182     if self._distribution_strategy:
   1183       distributed_training_utils._reset_metrics(self)  # pylint: disable=protected-access
   1184 
   1185   def train_on_batch(self,
   1186                      x,
   1187                      y=None,
   1188                      sample_weight=None,
   1189                      class_weight=None,
   1190                      reset_metrics=True):
   1191     """Runs a single gradient update on a single batch of data.
   1192 
   1193     Arguments:
   1194         x: Input data. It could be:
   1195           - A Numpy array (or array-like), or a list of arrays
   1196               (in case the model has multiple inputs).
   1197           - A TensorFlow tensor, or a list of tensors
   1198               (in case the model has multiple inputs).
   1199           - A dict mapping input names to the corresponding array/tensors,
   1200               if the model has named inputs.
   1201           - A `tf.data` dataset or a dataset iterator.
   1202         y: Target data. Like the input data `x`, it could be either Numpy
   1203           array(s) or TensorFlow tensor(s). It should be consistent with `x`
   1204           (you cannot have Numpy inputs and tensor targets, or inversely). If
   1205           `x` is a dataset or a dataset iterator, `y` should not be specified
   1206           (since targets will be obtained from the iterator).
   1207         sample_weight: Optional array of the same length as x, containing
   1208           weights to apply to the model's loss for each sample. In the case of
   1209           temporal data, you can pass a 2D array with shape (samples,
   1210           sequence_length), to apply a different weight to every timestep of
   1211           every sample. In this case you should make sure to specify
   1212           sample_weight_mode="temporal" in compile(). This argument is not
   1213           supported when `x` is a dataset or a dataset iterator.
   1214         class_weight: Optional dictionary mapping class indices (integers) to a
   1215           weight (float) to apply to the model's loss for the samples from this
   1216           class during training. This can be useful to tell the model to "pay
   1217           more attention" to samples from an under-represented class.
   1218         reset_metrics: If `True`, the metrics returned will be only for this
   1219           batch. If `False`, the metrics will be statefully accumulated across
   1220           batches.
   1221 
   1222     Returns:
   1223         Scalar training loss
   1224         (if the model has a single output and no metrics)
   1225         or list of scalars (if the model has multiple outputs
   1226         and/or metrics). The attribute `model.metrics_names` will give you
   1227         the display labels for the scalar outputs.
   1228 
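             Example:

             A minimal sketch (assumed setup, not from the original docs) showing
             one gradient update on a single random batch; the model, optimizer
             and data below are illustrative placeholders.

             ```python
             import numpy as np
             import tensorflow as tf

             # Illustrative model with one regression output and one metric.
             model = tf.keras.Sequential([
                 tf.keras.layers.Dense(4, activation='relu', input_shape=(8,)),
                 tf.keras.layers.Dense(1)
             ])
             model.compile(optimizer='sgd', loss='mse', metrics=['mae'])

             x = np.random.random((32, 8))
             y = np.random.random((32, 1))
             loss, mae = model.train_on_batch(x, y)  # one update; returns [loss, mae]
             ```
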
   1229     Raises:
   1230       ValueError: In case of invalid user-provided arguments.
   1231     """
   1232     if self._distribution_strategy:
   1233       raise NotImplementedError('`train_on_batch` is not supported for models '
   1234                                 'compiled with DistributionStrategy.')
   1235     # Validate and standardize user data.
   1236     x, y, sample_weights = self._standardize_user_data(
   1237         x, y, sample_weight=sample_weight, class_weight=class_weight,
   1238         extract_tensors_from_dataset=True)
   1239 
   1240     if self.run_eagerly:
   1241       outputs = training_eager.train_on_batch(
   1242           self,
   1243           x,
   1244           y,
   1245           sample_weights=sample_weights,
   1246           output_loss_metrics=self._output_loss_metrics)
   1247     else:
   1248       x = training_utils.ModelInputs(x).as_list()
   1249       ins = x + (y or []) + (sample_weights or [])
   1250 
   1251       if not isinstance(K.symbolic_learning_phase(), int):
   1252         ins += [True]  # Add learning phase value.
   1253 
   1254       self._make_train_function()
   1255       outputs = self.train_function(ins)  # pylint: disable=not-callable
   1256 
   1257     if reset_metrics:
   1258       self.reset_metrics()
   1259 
   1260     if len(outputs) == 1:
   1261       return outputs[0]
   1262     return outputs
   1263 
   1264   def test_on_batch(self, x, y=None, sample_weight=None, reset_metrics=True):
   1265     """Test the model on a single batch of samples.
   1266 
   1267     Arguments:
   1268         x: Input data. It could be:
   1269           - A Numpy array (or array-like), or a list of arrays
   1270             (in case the model has multiple inputs).
   1271           - A TensorFlow tensor, or a list of tensors
   1272             (in case the model has multiple inputs).
   1273           - A dict mapping input names to the corresponding array/tensors,
   1274             if the model has named inputs.
   1275           - A `tf.data` dataset or a dataset iterator.
   1276         y: Target data. Like the input data `x`,
   1277           it could be either Numpy array(s) or TensorFlow tensor(s).
   1278           It should be consistent with `x` (you cannot have Numpy inputs and
   1279           tensor targets, or inversely). If `x` is a dataset or a
   1280           dataset iterator, `y` should not be specified
   1281           (since targets will be obtained from the iterator).
   1282         sample_weight: Optional array of the same length as x, containing
   1283             weights to apply to the model's loss for each sample.
   1284             In the case of temporal data, you can pass a 2D array
   1285             with shape (samples, sequence_length),
   1286             to apply a different weight to every timestep of every sample.
   1287             In this case you should make sure to specify
   1288             sample_weight_mode="temporal" in compile(). This argument is not
   1289             supported when `x` is a dataset or a dataset iterator.
   1290         reset_metrics: If `True`, the metrics returned will be only for this
   1291           batch. If `False`, the metrics will be statefully accumulated across
   1292           batches.
   1293 
   1294     Returns:
   1295         Scalar test loss (if the model has a single output and no metrics)
   1296         or list of scalars (if the model has multiple outputs
   1297         and/or metrics). The attribute `model.metrics_names` will give you
   1298         the display labels for the scalar outputs.
   1299 
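             Example:

             A minimal sketch (assumed setup, not from the original docs); the
             compiled model and random data below are illustrative placeholders.

             ```python
             import numpy as np
             import tensorflow as tf

             # Illustrative compiled model; test_on_batch updates no weights.
             model = tf.keras.Sequential([
                 tf.keras.layers.Dense(4, activation='relu', input_shape=(8,)),
                 tf.keras.layers.Dense(1)
             ])
             model.compile(optimizer='sgd', loss='mse', metrics=['mae'])

             x = np.random.random((32, 8))
             y = np.random.random((32, 1))
             loss, mae = model.test_on_batch(x, y)
             ```
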
   1300     Raises:
   1301         ValueError: In case of invalid user-provided arguments.
   1302     """
   1303     if self._distribution_strategy:
   1304       raise NotImplementedError('`test_on_batch` is not supported for models '
   1305                                 'compiled with DistributionStrategy.')
   1306     # Validate and standardize user data.
   1307     x, y, sample_weights = self._standardize_user_data(
   1308         x, y, sample_weight=sample_weight, extract_tensors_from_dataset=True)
   1309 
   1310     if self.run_eagerly:
   1311       outputs = training_eager.test_on_batch(
   1312           self,
   1313           x,
   1314           y,
   1315           sample_weights=sample_weights,
   1316           output_loss_metrics=self._output_loss_metrics)
   1317     else:
   1318       x = training_utils.ModelInputs(x).as_list()
   1319       inputs = x + (y or []) + (sample_weights or [])
   1320 
   1321       self._make_test_function()
   1322       outputs = self.test_function(inputs)  # pylint: disable=not-callable
   1323 
   1324     if reset_metrics:
   1325       self.reset_metrics()
   1326 
   1327     if len(outputs) == 1:
   1328       return outputs[0]
   1329     return outputs
   1330 
   1331   def predict_on_batch(self, x):
   1332     """Returns predictions for a single batch of samples.
   1333 
   1334     Arguments:
   1335         x: Input data. It could be:
   1336           - A Numpy array (or array-like), or a list of arrays
   1337             (in case the model has multiple inputs).
   1338           - A TensorFlow tensor, or a list of tensors
   1339             (in case the model has multiple inputs).
   1340           - A `tf.data` dataset or a dataset iterator.
   1341 
   1342     Returns:
   1343         Numpy array(s) of predictions.
   1344 
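             Example:

             A minimal sketch (assumed setup, not from the original docs); the
             model and the random batch below are illustrative placeholders.

             ```python
             import numpy as np
             import tensorflow as tf

             # Illustrative model; predict_on_batch runs the whole batch at once.
             model = tf.keras.Sequential([
                 tf.keras.layers.Dense(4, activation='relu', input_shape=(8,)),
                 tf.keras.layers.Dense(1)
             ])

             x = np.random.random((16, 8))
             predictions = model.predict_on_batch(x)  # shape (16, 1)
             ```
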
   1345     Raises:
   1346         ValueError: In case of mismatch between given number of inputs and
   1347           expectations of the model.
   1348     """
   1349     if self._distribution_strategy:
   1350       raise NotImplementedError('`predict_on_batch` is not supported for '
   1351                                 'models compiled with DistributionStrategy.')
   1352     # Validate and standardize user data.
   1353     inputs, _, _ = self._standardize_user_data(
   1354         x, extract_tensors_from_dataset=True)
   1355     if self.run_eagerly:
   1356       if (isinstance(inputs, iterator_ops.EagerIterator) or
   1357           (isinstance(inputs, dataset_ops.DatasetV2))):
   1358         inputs = training_utils.cast_if_floating_dtype(inputs)
   1359       elif isinstance(inputs, collections.Sequence):
   1360         inputs = [
   1361             ops.convert_to_tensor(val, dtype=K.floatx()) for val in inputs]
   1362 
   1363         # Unwrap lists with only one input, as we do when training on batch
   1364         if len(inputs) == 1:
   1365           inputs = inputs[0]
   1366 
   1367       return self(inputs)  # pylint: disable=not-callable
   1368 
   1369     self._make_predict_function()
   1370     outputs = self.predict_function(inputs)
   1371 
   1372     if len(outputs) == 1:
   1373       return outputs[0]
   1374     return outputs
   1375 
   1376   def fit_generator(self,
   1377                     generator,
   1378                     steps_per_epoch=None,
   1379                     epochs=1,
   1380                     verbose=1,
   1381                     callbacks=None,
   1382                     validation_data=None,
   1383                     validation_steps=None,
   1384                     validation_freq=1,
   1385                     class_weight=None,
   1386                     max_queue_size=10,
   1387                     workers=1,
   1388                     use_multiprocessing=False,
   1389                     shuffle=True,
   1390                     initial_epoch=0):
   1391     """Fits the model on data yielded batch-by-batch by a Python generator.
   1392 
   1393     The generator is run in parallel to the model, for efficiency.
   1394     For instance, this allows you to do real-time data augmentation
   1395     on images on CPU in parallel to training your model on GPU.
   1396 
    1397     The use of `keras.utils.Sequence` guarantees the ordering and
    1398     guarantees that every input is used exactly once per epoch when
    1399     using `use_multiprocessing=True`.
   1400 
   1401     Arguments:
   1402         generator: A generator or an instance of `Sequence`
   1403           (`keras.utils.Sequence`)
   1404             object in order to avoid duplicate data
   1405             when using multiprocessing.
   1406             The output of the generator must be either
   1407             - a tuple `(inputs, targets)`
   1408             - a tuple `(inputs, targets, sample_weights)`.
   1409             This tuple (a single output of the generator) makes a single batch.
   1410             Therefore, all arrays in this tuple must have the same length (equal
    1411             to the size of this batch). Different batches may have
    1412             different sizes. For example, the last batch of the epoch
    1413             is commonly smaller than the others, if the size of the
    1414             dataset is not divisible by the batch size.
   1417             The generator is expected to loop over its data
   1418             indefinitely. An epoch finishes when `steps_per_epoch`
   1419             batches have been seen by the model.
   1420         steps_per_epoch: Total number of steps (batches of samples)
   1421             to yield from `generator` before declaring one epoch
   1422             finished and starting the next epoch. It should typically
   1423             be equal to the number of samples of your dataset
   1424             divided by the batch size.
   1425             Optional for `Sequence`: if unspecified, will use
   1426             the `len(generator)` as a number of steps.
   1427         epochs: Integer, total number of iterations on the data.
   1428         verbose: Verbosity mode, 0, 1, or 2.
   1429         callbacks: List of callbacks to be called during training.
   1430         validation_data: This can be either
   1431             - a generator for the validation data
   1432             - a tuple (inputs, targets)
   1433             - a tuple (inputs, targets, sample_weights).
   1434         validation_steps: Only relevant if `validation_data`
   1435             is a generator. Total number of steps (batches of samples)
   1436             to yield from `generator` before stopping.
   1437             Optional for `Sequence`: if unspecified, will use
   1438             the `len(validation_data)` as a number of steps.
   1439         validation_freq: Only relevant if validation data is provided. Integer
   1440             or `collections.Container` instance (e.g. list, tuple, etc.). If an
   1441             integer, specifies how many training epochs to run before a new
   1442             validation run is performed, e.g. `validation_freq=2` runs
   1443             validation every 2 epochs. If a Container, specifies the epochs on
   1444             which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
   1445             validation at the end of the 1st, 2nd, and 10th epochs.
   1446         class_weight: Dictionary mapping class indices to a weight
   1447             for the class.
   1448         max_queue_size: Integer. Maximum size for the generator queue.
   1449             If unspecified, `max_queue_size` will default to 10.
   1450         workers: Integer. Maximum number of processes to spin up
   1451             when using process-based threading.
   1452             If unspecified, `workers` will default to 1. If 0, will
   1453             execute the generator on the main thread.
   1454         use_multiprocessing: Boolean.
   1455             If `True`, use process-based threading.
   1456             If unspecified, `use_multiprocessing` will default to `False`.
   1457             Note that because this implementation relies on multiprocessing,
   1458             you should not pass non-picklable arguments to the generator
   1459             as they can't be passed easily to children processes.
   1460         shuffle: Boolean. Whether to shuffle the order of the batches at
   1461             the beginning of each epoch. Only used with instances
   1462             of `Sequence` (`keras.utils.Sequence`).
   1463             Has no effect when `steps_per_epoch` is not `None`.
   1464         initial_epoch: Epoch at which to start training
   1465             (useful for resuming a previous training run)
   1466 
   1467     Returns:
   1468         A `History` object.
   1469 
   1470     Example:
   1471 
   1472     ```python
   1473         def generate_arrays_from_file(path):
    1474             while True:
    1475                 with open(path) as f:
    1476                     for line in f:
    1477                         # create numpy arrays of input data
    1478                         # and labels, from each line in the file
    1479                         x1, x2, y = process_line(line)
    1480                         yield ({'input_1': x1, 'input_2': x2},
    1481                                {'output': y})
   1482 
   1483         model.fit_generator(generate_arrays_from_file('/my_file.txt'),
   1484                             steps_per_epoch=10000, epochs=10)
   1485     ```
   1486     Raises:
   1487         ValueError: In case the generator yields data in an invalid format.
   1488     """
   1489     if self._distribution_strategy:
   1490       raise NotImplementedError('`fit_generator` is not supported for '
   1491                                 'models compiled with DistributionStrategy.')
   1492     return training_generator.fit_generator(
   1493         self,
   1494         generator,
   1495         steps_per_epoch=steps_per_epoch,
   1496         epochs=epochs,
   1497         verbose=verbose,
   1498         callbacks=callbacks,
   1499         validation_data=validation_data,
   1500         validation_steps=validation_steps,
   1501         validation_freq=validation_freq,
   1502         class_weight=class_weight,
   1503         max_queue_size=max_queue_size,
   1504         workers=workers,
   1505         use_multiprocessing=use_multiprocessing,
   1506         shuffle=shuffle,
   1507         initial_epoch=initial_epoch,
   1508         steps_name='steps_per_epoch')
   1509 
   1510   def evaluate_generator(self,
   1511                          generator,
   1512                          steps=None,
   1513                          callbacks=None,
   1514                          max_queue_size=10,
   1515                          workers=1,
   1516                          use_multiprocessing=False,
   1517                          verbose=0):
   1518     """Evaluates the model on a data generator.
   1519 
   1520     The generator should return the same kind of data
   1521     as accepted by `test_on_batch`.
   1522 
   1523     Arguments:
   1524         generator: Generator yielding tuples (inputs, targets)
   1525             or (inputs, targets, sample_weights)
   1526             or an instance of `keras.utils.Sequence`
   1527             object in order to avoid duplicate data
   1528             when using multiprocessing.
   1529         steps: Total number of steps (batches of samples)
   1530             to yield from `generator` before stopping.
   1531             Optional for `Sequence`: if unspecified, will use
   1532             the `len(generator)` as a number of steps.
   1533         callbacks: List of `keras.callbacks.Callback` instances.
   1534             List of callbacks to apply during evaluation.
   1535             See [callbacks](/api_docs/python/tf/keras/callbacks).
    1536         max_queue_size: Maximum size for the generator queue.
   1537         workers: Integer. Maximum number of processes to spin up
   1538             when using process-based threading.
   1539             If unspecified, `workers` will default to 1. If 0, will
   1540             execute the generator on the main thread.
   1541         use_multiprocessing: Boolean.
   1542             If `True`, use process-based threading.
   1543             If unspecified, `use_multiprocessing` will default to `False`.
   1544             Note that because this implementation relies on multiprocessing,
   1545             you should not pass non-picklable arguments to the generator
   1546             as they can't be passed easily to children processes.
   1547         verbose: Verbosity mode, 0 or 1.
   1548 
   1549     Returns:
   1550         Scalar test loss (if the model has a single output and no metrics)
   1551         or list of scalars (if the model has multiple outputs
   1552         and/or metrics). The attribute `model.metrics_names` will give you
   1553         the display labels for the scalar outputs.
   1554 
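             Example:

             A minimal sketch (assumed setup, not from the original docs): the
             endless generator of random Numpy batches below is an illustrative
             placeholder for a real data pipeline.

             ```python
             import numpy as np
             import tensorflow as tf

             model = tf.keras.Sequential([
                 tf.keras.layers.Dense(4, activation='relu', input_shape=(8,)),
                 tf.keras.layers.Dense(1)
             ])
             model.compile(optimizer='sgd', loss='mse')

             def eval_batches():
               # Loops indefinitely; `steps` below controls how many batches run.
               while True:
                 yield np.random.random((16, 8)), np.random.random((16, 1))

             loss = model.evaluate_generator(eval_batches(), steps=5)
             ```
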
    1555     Raises:
    1556         ValueError: In case of invalid arguments, or in case the generator
    1557             yields data in an invalid format.
   1560     """
   1561     if self._distribution_strategy:
   1562       raise NotImplementedError('`evaluate_generator` is not supported for '
   1563                                 'models compiled with DistributionStrategy.')
   1564     return training_generator.evaluate_generator(
   1565         self,
   1566         generator,
   1567         steps=steps,
   1568         max_queue_size=max_queue_size,
   1569         workers=workers,
   1570         use_multiprocessing=use_multiprocessing,
   1571         verbose=verbose,
   1572         callbacks=callbacks)
   1573 
   1574   def predict_generator(self,
   1575                         generator,
   1576                         steps=None,
   1577                         callbacks=None,
   1578                         max_queue_size=10,
   1579                         workers=1,
   1580                         use_multiprocessing=False,
   1581                         verbose=0):
   1582     """Generates predictions for the input samples from a data generator.
   1583 
   1584     The generator should return the same kind of data as accepted by
   1585     `predict_on_batch`.
   1586 
   1587     Arguments:
   1588         generator: Generator yielding batches of input samples
   1589             or an instance of `keras.utils.Sequence` object in order to
   1590             avoid duplicate data when using multiprocessing.
   1591         steps: Total number of steps (batches of samples)
   1592             to yield from `generator` before stopping.
   1593             Optional for `Sequence`: if unspecified, will use
   1594             the `len(generator)` as a number of steps.
   1595         callbacks: List of `keras.callbacks.Callback` instances.
   1596             List of callbacks to apply during prediction.
   1597             See [callbacks](/api_docs/python/tf/keras/callbacks).
   1598         max_queue_size: Maximum size for the generator queue.
   1599         workers: Integer. Maximum number of processes to spin up
   1600             when using process-based threading.
   1601             If unspecified, `workers` will default to 1. If 0, will
   1602             execute the generator on the main thread.
   1603         use_multiprocessing: Boolean.
   1604             If `True`, use process-based threading.
   1605             If unspecified, `use_multiprocessing` will default to `False`.
   1606             Note that because this implementation relies on multiprocessing,
   1607             you should not pass non-picklable arguments to the generator
   1608             as they can't be passed easily to children processes.
    1609         verbose: Verbosity mode, 0 or 1.
   1610 
   1611     Returns:
   1612         Numpy array(s) of predictions.
   1613 
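             Example:

             A minimal sketch (assumed setup, not from the original docs): the
             generator of random input batches below is an illustrative
             placeholder.

             ```python
             import numpy as np
             import tensorflow as tf

             model = tf.keras.Sequential([
                 tf.keras.layers.Dense(4, activation='relu', input_shape=(8,)),
                 tf.keras.layers.Dense(1)
             ])

             def input_batches():
               while True:
                 yield np.random.random((16, 8))

             predictions = model.predict_generator(input_batches(), steps=5)
             # 5 steps of 16 samples each -> predictions has shape (80, 1).
             ```
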
   1614     Raises:
   1615         ValueError: In case the generator yields data in an invalid format.
   1616     """
   1617     if self._distribution_strategy:
   1618       raise NotImplementedError('`predict_generator` is not supported for '
   1619                                 'models compiled with DistributionStrategy.')
   1620     return training_generator.predict_generator(
   1621         self,
   1622         generator,
   1623         steps=steps,
   1624         max_queue_size=max_queue_size,
   1625         workers=workers,
   1626         use_multiprocessing=use_multiprocessing,
   1627         verbose=verbose,
   1628         callbacks=callbacks)
   1629 
   1630   def _prepare_total_loss(self, skip_target_indices=None, masks=None):
   1631     """Computes total loss from loss functions.
   1632 
   1633     Arguments:
   1634         skip_target_indices: A list of indices of model outputs where loss
   1635           function is None.
   1636         masks: List of mask values corresponding to each model output.
   1637 
   1638     Returns:
    1639         Total loss tensor, computed as the weighted sum of the individual
                     output losses plus any regularization and layer-added losses.
   1640 
   1641     Raises:
   1642         TypeError: If model run_eagerly is True.
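
             Schematically (an approximate sketch, not the exact implementation),
             the returned value is:

                 total_loss = sum_i(loss_weight_i *
                                    reduce(loss_fn_i(y_true_i, y_pred_i, sw_i)))
                              + add_n(self.losses)

             where `self.losses` holds regularization and other layer-added losses,
             and outputs listed in `skip_target_indices` are skipped.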
   1643     """
   1644     if self.run_eagerly:
   1645       raise TypeError('total loss can not be computed when compiled with '
   1646                       'run_eagerly = True.')
   1647     skip_target_indices = skip_target_indices or []
   1648     total_loss = None
   1649     with K.name_scope('loss'):
   1650       zipped_inputs = zip(self.targets, self.outputs, self.loss_functions,
   1651                           self.sample_weights, masks, self.loss_weights_list)
   1652       for i, (y_true, y_pred, loss_fn, sample_weight, mask,
   1653               loss_weight) in enumerate(zipped_inputs):
   1654         if i in skip_target_indices:
   1655           continue
   1656         loss_name = self.output_names[i] + '_loss'
   1657         with K.name_scope(loss_name):
   1658           if mask is not None:
   1659             mask = math_ops.cast(mask, y_pred.dtype)
   1660             # Update weights with mask.
   1661             if sample_weight is None:
   1662               sample_weight = mask
   1663             else:
   1664               # Update dimensions of weights to match with mask if possible.
   1665               mask, _, sample_weight = (
   1666                   losses_utils.squeeze_or_expand_dimensions(
   1667                       mask, None, sample_weight))
   1668               sample_weight *= mask
   1669 
   1670           # Reset reduction on the loss so that we can get the per sample loss
   1671           # value. We use this to get both the stateless and stateful loss
   1672           # values without having to compute the underlying loss function
   1673           # twice.
   1674           weighted_losses = None
   1675           if hasattr(loss_fn, 'reduction'):
   1676             current_loss_reduction = loss_fn.reduction
   1677             loss_fn.reduction = losses_utils.ReductionV2.NONE
   1678             weighted_losses = loss_fn(
   1679                 y_true, y_pred, sample_weight=sample_weight)
   1680             loss_fn.reduction = current_loss_reduction
   1681 
   1682             # Compute the stateless loss value.
   1683             output_loss = losses_utils.reduce_weighted_loss(
   1684                 weighted_losses, reduction=current_loss_reduction)
   1685           else:
   1686             # Compute the stateless loss value for a custom loss class.
   1687             # Here we assume that the class takes care of loss reduction
   1688             # because if this class returns a vector value we cannot
   1689             # differentiate between use case where a custom optimizer
   1690             # expects a vector loss value vs unreduced per-sample loss value.
   1691             output_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
   1692 
   1693         if len(self.outputs) > 1:
   1694           # Keep track of stateful result tensor and function for the loss.
   1695           # Compute the stateful loss value.
   1696           if weighted_losses is not None:
   1697             # TODO(b/120571621): Directly call metric when the bug is fixed.
   1698             aggregated_output_loss = self._call_fn_for_each_replica(
   1699                 self._output_loss_metrics[i], weighted_losses)
   1700           else:
   1701             # Custom loss class.
   1702             aggregated_output_loss = self._call_metric_fn(
   1703                 self._output_loss_metrics[i], y_true, y_pred, sample_weight)
   1704           self._compile_metrics_tensors[loss_name] = aggregated_output_loss
   1705 
   1706         if total_loss is None:
   1707           total_loss = loss_weight * output_loss
   1708         else:
   1709           total_loss += loss_weight * output_loss
   1710       if total_loss is None:
   1711         if not self.losses:
   1712           raise ValueError('The model cannot be compiled '
   1713                            'because it has no loss to optimize.')
   1714         else:
   1715           total_loss = 0.
   1716 
   1717       # Add regularization penalties and other layer-specific losses.
   1718       if self.losses:
   1719         total_loss += losses_utils.scale_loss_for_distribution(
   1720             math_ops.add_n(self.losses))
   1721     return total_loss
   1722 
   1723   def _get_callback_model(self):
   1724     """Returns the Callback Model for this Model."""
   1725 
   1726     if hasattr(self, '_replicated_model') and self._replicated_model:
   1727       # When using training_distributed, we set the callback model
   1728       # to an instance of the `DistributedModel` that we create in
   1729       # the `compile` call. The `DistributedModel` is initialized
   1730       # with the first replicated model. We need to set the callback
   1731       # model to a DistributedModel to allow us to override saving
   1732       # and loading weights when we checkpoint the model during training.
   1733       return self._replicated_model
   1734     if hasattr(self, 'callback_model') and self.callback_model:
   1735       return self.callback_model
   1736     return self
   1737 
   1738   def _make_callback_model(self, grouped_model):
   1739     first_replicated_model = self._distribution_strategy.unwrap(
   1740         grouped_model)[0]
   1741     # We initialize the callback model with the first replicated model.
   1742     self._replicated_model = DistributedCallbackModel(first_replicated_model)
   1743     self._replicated_model.set_original_model(self)
   1744 
   1745   def _validate_or_infer_batch_size(self, batch_size, steps, x):
   1746     """Validates that the `batch_size` provided is consistent with InputLayer.
   1747 
   1748     It's possible that the user specified a static batch size in their
   1749     InputLayer. If so, this method checks the provided `batch_size` and `x`
   1750     arguments are consistent with this static batch size. Also, if
   1751     `batch_size` is `None`, this method will attempt to infer the batch size
   1752     from the static batch size of the InputLayer. Lastly, ValueError will be
   1753     raised if `x` is a tf.data.Dataset and `batch_size` is specified as we
   1754     expect users to provide batched datasets.
   1755 
   1756     Arguments:
   1757       batch_size: The batch_size provided as an argument to
   1758         fit/evaluate/predict.
   1759       steps: The steps provided as an argument to fit/evaluate/predict.
   1760       x: The data passed as `x` to fit/evaluate/predict.
   1761 
   1762     Returns:
   1763       The validated batch_size, auto-inferred from the first layer if not
   1764       provided.
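
             For example (illustrative only): if the model was built from
             `Input(shape=(8,), batch_size=16)`, then with `batch_size=None` and
             `steps=None` this returns 16, while passing `batch_size=32` raises a
             `ValueError`.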
   1765     """
   1766     if batch_size is not None and isinstance(x, dataset_ops.DatasetV2):
   1767       raise ValueError('The `batch_size` argument must not be specified when'
   1768                        ' using dataset as an input.')
   1769 
   1770     layers = super(Model, self).layers  # Avoids the override in Sequential.
   1771     if layers:
   1772       first_layer = layers[0]
   1773       static_batch_size = training_utils.get_static_batch_size(first_layer)
   1774       if static_batch_size is not None:
   1775 
   1776         # Check `batch_size` argument is consistent with InputLayer.
   1777         if batch_size is not None and batch_size != static_batch_size:
   1778           raise ValueError('The `batch_size` argument value {} is incompatible '
   1779                            'with the specified batch size of your Input Layer: '
   1780                            '{}'.format(batch_size, static_batch_size))
   1781 
   1782         # Check Dataset/Iterator batch size is consistent with InputLayer.
   1783         if isinstance(x, (dataset_ops.DatasetV2, iterator_ops.Iterator,
   1784                           iterator_ops.EagerIterator)):
   1785           ds_batch_size = tensor_shape.as_dimension(
   1786               nest.flatten(dataset_ops.get_legacy_output_shapes(x))[0][0]).value
   1787           if ds_batch_size is not None and ds_batch_size != static_batch_size:
   1788             raise ValueError('The batch output shape of your `Dataset` is {}, '
   1789                              'which is incompatible with the specified batch '
   1790                              'size of your Input Layer: {}'.format(
   1791                                  ds_batch_size, static_batch_size))
   1792 
   1793         # Set inferred batch size from the InputLayer.
   1794         if steps is None:
   1795           batch_size = static_batch_size
   1796 
   1797     if batch_size is None and steps is None:
   1798       # Backwards compatibility
   1799       batch_size = 32
   1800     return batch_size
   1801 
   1802   def _list_functions_for_serialization(self):
   1803     return {
   1804         '_default_save_signature': saving_utils.trace_model_call(self)
   1805     }
   1806 
   1807   def _set_sample_weight_attributes(self, sample_weight_mode,
   1808                                     skip_target_weighing_indices):
   1809     """Sets sample weight related attributes on the model."""
   1810     sample_weights, sample_weight_modes = training_utils.prepare_sample_weights(
   1811         self.output_names, sample_weight_mode, skip_target_weighing_indices)
   1812     self.sample_weights = sample_weights
   1813     self.sample_weight_modes = sample_weight_modes
   1814     self._feed_sample_weight_modes = [
   1815         sample_weight_modes[i]
   1816         for i in range(len(self.outputs))
   1817         if i not in skip_target_weighing_indices
   1818     ]
   1819     self._feed_sample_weights = [
   1820         sample_weights[i]
   1821         for i in range(len(sample_weights))
   1822         if i not in skip_target_weighing_indices
   1823     ]
   1824 
   1825   def _cache_output_metric_attributes(self, metrics, weighted_metrics):
   1826     """Caches metric name and function attributes for every model output."""
   1827     output_shapes = []
   1828     for output in self.outputs:
   1829       if output is None or output.shape.rank is None:
   1830         output_shapes.append(None)
   1831       else:
   1832         output_shapes.append(output.shape.as_list())
   1833     self._per_output_metrics = training_utils.collect_per_output_metric_info(
   1834         metrics, self.output_names, output_shapes, self.loss_functions)
   1835     self._per_output_weighted_metrics = (
   1836         training_utils.collect_per_output_metric_info(
   1837             weighted_metrics,
   1838             self.output_names,
   1839             output_shapes,
   1840             self.loss_functions,
   1841             is_weighted=True))
   1842 
   1843   def _add_unique_metric_name(self, metric_name, output_index):
   1844     """Makes the metric name unique and adds it to the model's metric name list.
   1845 
   1846       If there are multiple outputs for which the metrics are calculated, the
   1847       metric names have to be made unique by appending an integer.
   1848 
   1849     Arguments:
   1850       metric_name: Metric name that corresponds to the metric specified by the
   1851           user. For example: 'acc'.
   1852       output_index: The index of the model output for which the metric name is
   1853         being added.
   1854 
   1855     Returns:
    1856       String, the unique metric name for the given model output.
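
             For example (illustrative): with output names `['out_a', 'out_b']` and
             `metric_name='acc'`, output index 0 yields 'out_a_acc'; if that name
             is already taken, 'out_a_acc_1' is returned instead.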
   1857     """
   1858     if len(self.output_names) > 1:
   1859       metric_name = '%s_%s' % (self.output_names[output_index], metric_name)
   1860     j = 1
   1861     base_metric_name = metric_name
   1862     while metric_name in self._compile_metrics_names:
   1863       metric_name = '%s_%d' % (base_metric_name, j)
   1864       j += 1
   1865 
   1866     return metric_name
   1867 
   1868   @property
   1869   def _all_metrics_tensors(self):
   1870     """Returns a dictionary that maps metric names to metric result tensors.
   1871 
    1872     This maps metric names from `model.metrics_names` to result tensors.
    1873     Just like `model.metrics_names`, this includes loss names and tensors.
   1874     """
   1875     metrics_tensors = {}
   1876     if self._is_compiled:
   1877       metrics_tensors.update(self._compile_metrics_tensors)
   1878     metrics_tensors.update(super(Model, self)._all_metrics_tensors)
   1879     return metrics_tensors
   1880 
   1881   def _init_metric_attributes(self):
    1882     """Initializes model metric attributes."""
   1883     # List of all metric names in the model. This includes loss metrics.
   1884     self._compile_metrics_names = ['loss']
   1885     # List of stateful metric functions. Used for resetting metric state during
   1886     # training/eval. This includes loss metric functions.
   1887     self._compile_metric_functions = []
   1888     # Dict of all aggregated metric result tensors. This includes aggregated
   1889     # loss result tensors.
   1890     self._compile_metrics_tensors = {}
   1891     # List of metric wrappers on output losses.
   1892     self._output_loss_metrics = None
   1893 
   1894   def _set_per_output_metric_attributes(self, metrics_dict, output_index):
   1895     """Sets the metric attributes on the model for the given output.
   1896 
   1897     Arguments:
   1898       metrics_dict: A dict with metric names as keys and metric fns as values.
   1899       output_index: The index of the model output for which the metric
   1900         attributes are added.
   1901 
   1902     Returns:
   1903       Metrics dict updated with unique metric names as keys.
   1904     """
   1905     updated_metrics_dict = collections.OrderedDict()
   1906     for metric_name, metric_fn in metrics_dict.items():
   1907       metric_name = self._add_unique_metric_name(metric_name, output_index)
   1908 
   1909       # Update the name on the metric class to be the unique generated name.
   1910       metric_fn._name = metric_name  # pylint: disable=protected-access
   1911       updated_metrics_dict[metric_name] = metric_fn
   1912       # Keep track of metric name and function.
   1913       self._compile_metrics_names.append(metric_name)
   1914       self._compile_metric_functions.append(metric_fn)
   1915     return updated_metrics_dict
   1916 
   1917   def _set_metric_attributes(self, skip_target_indices=None):
   1918     """Sets the metric attributes on the model for all the model outputs."""
   1919     # Add loss metric names to the model metric names list.
   1920     if len(self.outputs) > 1:
   1921       output_names = [
   1922           self.output_names[i] + '_loss'
   1923           for i in range(len(self.outputs))
   1924           if i not in skip_target_indices
   1925       ]
   1926       self._compile_metrics_names.extend(output_names)
   1927 
   1928     skip_target_indices = skip_target_indices or []
   1929     updated_per_output_metrics = []
   1930     updated_per_output_weighted_metrics = []
   1931     for i in range(len(self.outputs)):
   1932       if i in skip_target_indices:
   1933         updated_per_output_metrics.append(self._per_output_metrics[i])
   1934         updated_per_output_weighted_metrics.append(
   1935             self._per_output_weighted_metrics[i])
   1936         continue
   1937       updated_per_output_metrics.append(
   1938           self._set_per_output_metric_attributes(self._per_output_metrics[i],
   1939                                                  i))
   1940       updated_per_output_weighted_metrics.append(
   1941           self._set_per_output_metric_attributes(
   1942               self._per_output_weighted_metrics[i], i))
   1943 
   1944     # Create a metric wrapper for each output loss.
   1945     if len(self.outputs) > 1:
   1946       self._output_loss_metrics = [
   1947           metrics_module.SumOverBatchSize() if hasattr(loss_fn, 'reduction')
   1948           else metrics_module.SumOverBatchSizeMetricWrapper(loss_fn)
   1949           for loss_fn in self.loss_functions
   1950       ]
   1951 
   1952     self._per_output_metrics = updated_per_output_metrics
   1953     self._per_output_weighted_metrics = updated_per_output_weighted_metrics
   1954 
   1955   def _call_metric_fn(self, metric_fn, y_true, y_pred, weights, mask=None):
   1956     # TODO(b/120571621): Remove this function when the bug is fixed.
   1957     """Helper function to call metric function with distribution strategy."""
   1958     return self._call_fn_for_each_replica(
   1959         training_utils.call_metric_function,
   1960         metric_fn,
   1961         y_true,
   1962         y_pred,
   1963         weights=weights,
   1964         mask=mask)
   1965 
   1966   def _call_fn_for_each_replica(self, fn, *args, **kwargs):
    1967     # TODO(b/120571621): We want to avoid metric reductions here since
    1968     # TPUStrategy does not implement replica local variables.
   1969     # Remove this hack once we support TPUReplicaLocalVariables.
   1970     is_tpu = distributed_training_utils.is_tpu_strategy(
   1971         self._distribution_strategy)
   1972     if ((not is_tpu) and self._distribution_strategy and
   1973         distribution_strategy_context.in_cross_replica_context()):
   1974       with self._distribution_strategy.scope():
   1975         return self._distribution_strategy.extended.call_for_each_replica(
   1976             fn, args, kwargs)
   1977     return fn(*args, **kwargs)
   1978 
   1979   def _handle_per_output_metrics(self,
   1980                                  metrics_dict,
   1981                                  y_true,
   1982                                  y_pred,
   1983                                  mask,
   1984                                  weights=None):
   1985     """Calls metric functions for a single output.
   1986 
   1987     Arguments:
   1988       metrics_dict: A dict with metric names as keys and metric fns as values.
   1989       y_true: Target output.
   1990       y_pred: Predicted output.
   1991       mask: Computed mask value for the current output.
   1992       weights: Weights to be applied on the current output.
   1993 
   1994     Returns:
   1995       A list of metric result tensors.
   1996     """
   1997     metric_results = []
   1998     for metric_name, metric_fn in metrics_dict.items():
   1999       with K.name_scope(metric_name):
   2000         metric_result = self._call_metric_fn(metric_fn, y_true, y_pred, weights,
   2001                                              mask)
   2002         metric_results.append(metric_result)
   2003         if not self.run_eagerly:
   2004           self._compile_metrics_tensors[metric_name] = metric_result
   2005 
   2006     return metric_results
   2007 
   2008   def _handle_metrics(self,
   2009                       outputs,
   2010                       skip_target_indices=None,
   2011                       targets=None,
   2012                       sample_weights=None,
   2013                       masks=None):
   2014     """Handles calling metric functions.
   2015 
   2016     Arguments:
   2017       outputs: List of outputs (predictions).
   2018       skip_target_indices: Optional. List of target ids to skip.
   2019       targets: List of targets.
   2020       sample_weights: Optional list of sample weight arrays.
   2021       masks: List of computed output mask values.
   2022 
   2023     Returns:
   2024       A list of metric result tensors.
   2025     """
   2026     skip_target_indices = skip_target_indices or []
   2027     metric_results = []
   2028     with K.name_scope('metrics'):
   2029       # Invoke all metrics added using `compile`.
   2030       for i in range(len(outputs)):
   2031         if i in skip_target_indices:
   2032           continue
   2033         output = outputs[i] if outputs else None
   2034         target = targets[i] if targets else None
   2035         output_mask = masks[i] if masks else None
   2036         metric_results.extend(
   2037             self._handle_per_output_metrics(self._per_output_metrics[i], target,
   2038                                             output, output_mask))
   2039         metric_results.extend(
   2040             self._handle_per_output_metrics(
   2041                 self._per_output_weighted_metrics[i],
   2042                 target,
   2043                 output,
   2044                 output_mask,
   2045                 weights=sample_weights[i]))
   2046 
   2047     # Add metric results from the `add_metric` metrics in eager mode.
   2048     if context.executing_eagerly():
   2049       for m in self.metrics:
   2050         if m not in self._compile_metric_functions:
   2051           metric_results.append(m.result())
   2052     return metric_results
   2053 
   2054   def _check_trainable_weights_consistency(self):
   2055     """Check trainable weights count consistency.
   2056 
   2057     This will raise a warning if `trainable_weights` and
   2058     `_collected_trainable_weights` are inconsistent (i.e. have different
   2059     number of parameters).
   2060     Inconsistency will typically arise when one modifies `model.trainable`
   2061     without calling `model.compile` again.
   2062     """
   2063     if not hasattr(self, '_collected_trainable_weights'):
   2064       return
   2065 
   2066     if len(self.trainable_weights) != len(self._collected_trainable_weights):
   2067       logging.log_first_n(
   2068           logging.WARN, 'Discrepancy between trainable weights and collected'
   2069           ' trainable weights, did you set `model.trainable`'
    2070           ' without calling `model.compile` afterwards?', 1)
   2071 
   2072   def _make_train_function(self):
   2073     metrics_tensors = [
   2074         self._all_metrics_tensors[m] for m in self.metrics_names[1:]
   2075     ]
   2076     if not self._is_compiled:
   2077       raise RuntimeError('You must compile your model before using it.')
   2078     self._check_trainable_weights_consistency()
   2079     if getattr(self, 'train_function') is None:
   2080       inputs = (self._feed_inputs +
   2081                 self._feed_targets +
   2082                 self._feed_sample_weights)
   2083       if not isinstance(K.symbolic_learning_phase(), int):
   2084         inputs += [K.symbolic_learning_phase()]
   2085 
   2086       with K.get_graph().as_default():
   2087         with K.name_scope('training'):
   2088           with K.name_scope(self.optimizer.__class__.__name__):
   2089             # Training updates
   2090             updates = self.optimizer.get_updates(
   2091                 params=self._collected_trainable_weights, loss=self.total_loss)
   2092       # Unconditional updates
   2093       updates += self.get_updates_for(None)
   2094       # Conditional updates relevant to this model
   2095       updates += self.get_updates_for(self.inputs)
   2096 
   2097       with K.name_scope('training'):
   2098         # Gets loss and metrics. Updates weights at each call.
   2099         fn = K.function(
   2100             inputs, [self.total_loss] + metrics_tensors,
   2101             updates=updates,
   2102             name='train_function',
   2103             **self._function_kwargs)
   2104         setattr(self, 'train_function', fn)
   2105 
   2106   def _make_test_function(self):
   2107     metrics_tensors = [
   2108         self._all_metrics_tensors[m] for m in self.metrics_names[1:]
   2109     ]
   2110     if not self._is_compiled:
   2111       raise RuntimeError('You must compile your model before using it.')
   2112     if getattr(self, 'test_function') is None:
   2113       inputs = (self._feed_inputs +
   2114                 self._feed_targets +
   2115                 self._feed_sample_weights)
   2116 
   2117       with K.name_scope('evaluation'):
   2118         updates = self.state_updates
   2119         # Return loss and metrics, no gradient updates.
   2120         # Does update the network states.
   2121         fn = K.function(
   2122             inputs, [self.total_loss] + metrics_tensors,
   2123             updates=updates,
   2124             name='test_function',
   2125             **self._function_kwargs)
   2126         setattr(self, 'test_function', fn)
   2127 
   2128   def _make_predict_function(self):
   2129     if not hasattr(self, 'predict_function'):
   2130       self.predict_function = None
   2131     if self.predict_function is None:
   2132       inputs = self._feed_inputs
   2133       # Gets network outputs. Does not update weights.
   2134       # Does update the network states.
   2135       kwargs = getattr(self, '_function_kwargs', {})
   2136       with K.name_scope(ModeKeys.PREDICT):
   2137         self.predict_function = K.function(
   2138             inputs,
   2139             self.outputs,
   2140             updates=self.state_updates,
   2141             name='predict_function',
   2142             **kwargs)
   2143 
   2144   def _make_execution_function(self, mode):
   2145     if mode == ModeKeys.TRAIN:
   2146       self._make_train_function()
   2147       return self.train_function
   2148     if mode == ModeKeys.TEST:
   2149       self._make_test_function()
   2150       return self.test_function
   2151     if mode == ModeKeys.PREDICT:
   2152       self._make_predict_function()
   2153       return self.predict_function
   2154 
   2155   def _distribution_standardize_user_data(self,
   2156                                           x,
   2157                                           y=None,
   2158                                           sample_weight=None,
   2159                                           class_weight=None,
   2160                                           batch_size=None,
   2161                                           validation_split=0,
   2162                                           shuffle=False,
   2163                                           repeat=False,
   2164                                           allow_partial_batch=False):
   2165     """Runs validation checks on input and target data passed by the user.
   2166 
   2167     This is called when using DistributionStrategy to train, evaluate or serve
   2168     the model.
   2169 
   2170     Args:
   2171       x: Input data. A numpy array or `tf.data` dataset.
   2172       y: Target data. A numpy array or None if x is a `tf.data` dataset.
   2173       sample_weight: An optional sample-weight array passed by the user to
   2174         weight the importance of each sample in `x`.
   2175       class_weight: An optional class-weight array by the user to
   2176         weight the importance of samples in `x` based on the class they belong
   2177         to, as conveyed by `y`.
    2178       batch_size: Integer batch size. When `x` is a Numpy array, it is used
    2179         to batch the data into a dataset and to size the shuffle buffer.
   2180       validation_split: Float between 0 and 1.
   2181         Fraction of the training data to be used as validation data.
   2182       shuffle: Boolean whether to shuffle the training data before each epoch.
   2183       repeat: Boolean whether to repeat the numpy training data when converting
   2184         to training dataset.
    2185       allow_partial_batch: Boolean whether to allow a final partial batch,
    2186         i.e. not to enforce that all batches have the same size.
   2187 
   2188     Returns:
   2189       Dataset instance.
   2190 
   2191     Raises:
   2192       ValueError: In case of invalid user-provided data.
   2193       RuntimeError: If the model was never compiled.
   2194     """
   2195     if class_weight:
   2196       raise NotImplementedError('`class_weight` is currently not supported '
   2197                                 'when using DistributionStrategy.')
   2198 
   2199     if (sample_weight is not None and sample_weight.all() and
   2200         distributed_training_utils.is_tpu_strategy(
   2201             self._distribution_strategy)):
   2202       raise NotImplementedError('`sample_weight` is currently not supported '
   2203                                 'when using TPUStrategy.')
   2204 
   2205     if (self.stateful and distributed_training_utils.is_tpu_strategy(
   2206         self._distribution_strategy) and self._distribution_strategy.
   2207         num_replicas_in_sync != 1):
   2208       raise ValueError('Single core must be used for computation on '
   2209                        'stateful models. Consider adding `device_assignment` '
   2210                        'parameter to TPUStrategy using\n'
   2211                        'topology = tf.contrib.distribute.'
   2212                        'initialize_tpu_system()\n'
   2213                        'device_assignment = tf.contrib.tpu.DeviceAssignment('
   2214                        'topology, core_assignment=tf.contrib.tpu.'
   2215                        'SINGLE_CORE_ASSIGNMENT)\n'
   2216                        'tpu_strategy = tf.contrib.distribute.TPUStrategy('
   2217                        'device_assignment=device_assignment)')
   2218 
   2219     # Validates `steps` and `shuffle` arguments right at the beginning
   2220     # since we use it to construct the dataset object.
   2221     # TODO(anjalisridhar): Remove this check once we refactor the
   2222     # _standardize_user_data code path. This check is already present elsewhere
   2223     # in the codebase.
   2224     if isinstance(x, dataset_ops.DatasetV2):
   2225       if shuffle:
   2226         training_utils.verify_dataset_shuffled(x)
   2227 
   2228     strategy = self._distribution_strategy
   2229     with strategy.scope():
   2230       # We should be sure to call get_session() inside the strategy.scope()
   2231       # so the strategy can affect the session options.
   2232       if ops.executing_eagerly_outside_functions():
   2233         session = None
   2234       else:
   2235         session = K.get_session()
   2236 
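              # Peek at the first flattened element of `x` to decide whether we were
              # given in-memory numpy data (which is converted to a `tf.data.Dataset`
              # below) or an existing dataset that is validated and passed through.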
   2237       first_x_value = nest.flatten(x)[0]
   2238       if isinstance(first_x_value, np.ndarray):
   2239         x = distributed_training_utils.list_to_tuple(x)
   2240         if y is not None:
   2241           y = distributed_training_utils.list_to_tuple(y)
   2242           if sample_weight is not None:
   2243             sample_weight = distributed_training_utils.list_to_tuple(
   2244                 sample_weight)
   2245             in_tuple = (x, y, sample_weight)
   2246           else:
   2247             in_tuple = (x, y)
   2248         else:
   2249           in_tuple = x
   2250 
   2251         ds = strategy.extended.experimental_make_numpy_dataset(in_tuple,
   2252                                                                session=session)
   2253         if shuffle:
   2254           # We want a buffer size that is larger than the batch size provided by
   2255           # the user and provides sufficient randomness. Note that larger
   2256           # numbers introduce more memory usage based on the size of each
   2257           # sample.
   2258           ds = ds.shuffle(max(1024, batch_size * 8))
   2259         if repeat:
   2260           ds = ds.repeat()
   2261 
   2262         # We need to use the drop_remainder argument to get a known static
   2263         # input shape which is required for TPUs.
   2264         drop_remainder = (not allow_partial_batch and
   2265                           strategy.extended.experimental_require_static_shapes)
   2266         x = ds.batch(batch_size, drop_remainder=drop_remainder)
   2267       else:
   2268         assert isinstance(x, dataset_ops.DatasetV2)
   2269         training_utils.validate_dataset_input(x, y, sample_weight,
   2270                                               validation_split)
   2271     return x
   2272 
   2273   def _standardize_user_data(self,
   2274                              x,
   2275                              y=None,
   2276                              sample_weight=None,
   2277                              class_weight=None,
   2278                              batch_size=None,
   2279                              check_steps=False,
   2280                              steps_name='steps',
   2281                              steps=None,
   2282                              validation_split=0,
   2283                              shuffle=False,
   2284                              extract_tensors_from_dataset=False):
   2285     """Runs validation checks on input and target data passed by the user.
   2286 
   2287     Also standardizes the data to lists of arrays, in order.
   2288 
   2289     Also builds and compiles the model on the fly if it is a subclassed model
   2290     that has never been called before (and thus has no inputs/outputs).
   2291 
   2292     This is a purely internal method, subject to refactoring at any time.
   2293 
   2294     Args:
   2295       x: Input data. It could be:
   2296         - A Numpy array (or array-like), or a list of arrays
   2297           (in case the model has multiple inputs).
   2298         - A TensorFlow tensor, or a list of tensors
   2299           (in case the model has multiple inputs).
   2300         - A dict mapping input names to the corresponding array/tensors,
   2301           if the model has named inputs.
   2302         - A `tf.data` dataset or a dataset iterator.
   2303       y: Target data. Like the input data `x`,
   2304         it could be either Numpy array(s) or TensorFlow tensor(s).
   2305         It should be consistent with `x` (you cannot have Numpy inputs and
   2306         tensor targets, or inversely). If `x` is a dataset or a
   2307         dataset iterator, `y` should not be specified
   2308         (since targets will be obtained from the iterator).
   2309       sample_weight: An optional sample-weight array passed by the user to
   2310         weight the importance of each sample in `x`.
   2311       class_weight: An optional class-weight array passed by the user to
   2312         weight the importance of samples in `x` based on the class they belong
   2313         to, as conveyed by `y`. If both `sample_weight` and `class_weight` are
   2314         provided, the weights are multiplied.
   2315       batch_size: Integer batch size. If provided, it is used to run additional
   2316         validation checks on stateful models.
   2317       check_steps: boolean, True if we want to check the validity of `steps`,
   2318         False otherwise. For example, when we are standardizing one batch of
   2319         data for train_on_batch/predict_on_batch/test_on_batch APIs, `steps`
   2320         value is not required and we should not check for its validity in these
   2321         cases.
   2322       steps_name: The public API's parameter name for `steps`.
   2323       steps: Integer or `None`. Total number of steps (batches of samples) to
   2324         execute.
   2325       validation_split: Float between 0 and 1.
   2326         Fraction of the training data to be used as validation data.
   2327       shuffle: Boolean whether to shuffle the training data before each epoch.
   2328       extract_tensors_from_dataset: Boolean. When `x` is a dataset instance,
   2329         this indicates whether to extract actual tensors from the dataset or
   2330         instead output the dataset instance itself.
   2331         Set to True when calling from `train_on_batch`/etc.
   2332 
   2333     Returns:
   2334       A tuple of 3: inputs (arrays or dicts, depending on whether `x` was a dict
   2335       or not), target arrays, sample-weight arrays.
   2336       If the model's input and targets are symbolic, these lists are empty
   2337       (since the model takes no user-provided data; instead the data comes
   2338       from the symbolic inputs/targets).
   2339 
   2340     Raises:
   2341       ValueError: In case of invalid user-provided data.
   2342       RuntimeError: If the model was never compiled.
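
            Example (a minimal illustrative sketch of the numpy path; the model and
            data below are assumptions used only for demonstration):

            ```python
            model = tf.keras.Sequential(
                [tf.keras.layers.Dense(1, input_shape=(3,))])
            model.compile('sgd', 'mse')
            x, y, sample_weights = model._standardize_user_data(
                np.ones((10, 3)), np.ones((10, 1)))
            # `x` and `y` are lists of numpy arrays, one entry per model
            # input/output, and `sample_weights` is a matching per-output list.
            ```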
   2343     """
   2344     if isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
   2345       # Graph mode dataset. We'll pass the dataset as-is (unless
   2346       # `extract_tensors_from_dataset` is True, in which case we extract
   2347       # the tensors from the dataset and output those instead).
   2348       training_utils.validate_dataset_input(x, y, sample_weight,
   2349                                             validation_split)
   2350       if shuffle:
   2351         training_utils.verify_dataset_shuffled(x)
   2352 
   2353       is_dataset = True
   2354       if extract_tensors_from_dataset:
   2355         # We do this for `train_on_batch`/etc.
   2356         x, y, sample_weight = training_utils.extract_tensors_from_dataset(x)
   2357     elif isinstance(x, iterator_ops.Iterator):
   2358       # Graph mode iterator. We extract the symbolic tensors.
   2359       training_utils.validate_dataset_input(x, y, sample_weight,
   2360                                             validation_split)
   2361       iterator = x
   2362       x, y, sample_weight = training_utils.unpack_iterator_input(iterator)
   2363       is_dataset = True
   2364     else:
   2365       is_dataset = False
   2366 
   2367     # Validates `steps` argument based on x's type.
   2368     if check_steps:
   2369       training_utils.check_steps_argument(x, steps, steps_name)
   2370 
   2371     # First, we build/compile the model on the fly if necessary.
   2372     all_inputs = []
   2373     is_build_called = False
   2374     is_compile_called = False
   2375     # Whether this is a subclassed model that expects dictionary inputs
   2376     # rather than list inputs (e.g. FeatureColumn-based models).
   2377     dict_inputs = False
   2378     if not self.inputs:
   2379       # We need to use `x_input` to set the model inputs.
   2380 
   2381       # If the input data is a dataset (or dataset iterator), we fetch one
   2382       # batch of data tensors from it and standardize those tensors in order
   2383       # to set the model inputs.
   2384       if isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
   2385         x_input, y_input, _ = training_utils.extract_tensors_from_dataset(x)
   2386       else:
   2387         x_input = x
   2388         y_input = y
   2389       # We type-check that `x_input` and `y_input` are either single arrays
   2390       # or lists of arrays.
   2391       if isinstance(x_input, (list, tuple)):
   2392         if not all(isinstance(v, np.ndarray) or
   2393                    tensor_util.is_tensor(v) for v in x_input):
   2394           raise ValueError('Please provide as model inputs either a single '
   2395                            'array or a list of arrays. You passed: x=' + str(x))
   2396         all_inputs += list(x_input)
   2397       elif isinstance(x_input, dict):
   2398         dict_inputs = True
   2399         keys = sorted(x_input.keys())
   2400         all_inputs = [x_input[k] for k in keys]
   2401       else:
   2402         if (not isinstance(x_input, np.ndarray) and
   2403             not tensor_util.is_tensor(x_input)):
   2404           raise ValueError('Please provide as model inputs either a single '
   2405                            'array or a list of arrays. You passed: x=' + str(x))
   2406         all_inputs.append(x_input)
   2407 
   2408       # Build the model using the retrieved inputs (value or symbolic).
   2409       # If the inputs are concrete values or come from a dataset, then in
   2410       # symbolic mode placeholders will be created to match the value shapes.
   2411       is_build_called = True
   2412       if is_dataset:
   2413         cast_inputs = nest.map_structure(lambda v: v.shape, x_input)
   2414       elif training_utils.has_tensors(x_input):
   2415         cast_inputs = training_utils.cast_if_floating_dtype(x_input)
   2416       else:
   2417         cast_inputs = x_input
   2418       self._set_inputs(cast_inputs)
   2419     else:
   2420       y_input = y
   2421       dict_inputs = isinstance(self.inputs, dict)
   2422 
   2423     if y_input is not None:
   2424       if not self.optimizer:
   2425         raise RuntimeError('You must compile a model before '
   2426                            'training/testing. '
   2427                            'Use `model.compile(optimizer, loss)`.')
   2428       if not self._is_compiled:
   2429         # On-the-fly compilation of the model.
   2430         # We need to use `y` to set the model targets.
   2431         if training_utils.has_tensors(y_input):
   2432           y_input = training_utils.cast_if_floating_dtype(y_input)
   2433         if isinstance(y_input, (list, tuple)):
   2434           if not all(isinstance(v, np.ndarray) or
   2435                      tensor_util.is_tensor(v) for v in y_input):
   2436             raise ValueError('Please provide as model targets either a single '
   2437                              'array or a list of arrays. '
   2438                              'You passed: y=' + str(y))
   2439           all_inputs += list(y_input)
   2440         elif isinstance(y_input, dict):
   2441           raise ValueError('You cannot pass a dictionary as model targets.')
   2442         else:
   2443           if (not isinstance(y_input, np.ndarray) and
   2444               not tensor_util.is_tensor(y_input)):
   2445             raise ValueError('Please provide as model targets either a single '
   2446                              'array or a list of arrays. '
   2447                              'You passed: y=' + str(y))
   2448           all_inputs.append(y_input)
   2449 
   2450         # Typecheck that all inputs are *either* value *or* symbolic.
   2451         # TODO(fchollet): this check could be removed in Eager mode?
   2452         if any(tensor_util.is_tensor(v) for v in all_inputs):
   2453           if not all(tensor_util.is_tensor(v) for v in all_inputs):
   2454             raise ValueError('Do not pass inputs that mix Numpy arrays and '
   2455                              'TensorFlow tensors. '
   2456                              'You passed: x=' + str(x) + '; y=' + str(y))
   2457 
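                # When compiling on the fly, datasets and eager execution feed targets
                # at run time, so no symbolic target tensors are needed; otherwise any
                # symbolic tensors found in `y` become compile-time target tensors.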
   2458         if is_dataset or context.executing_eagerly():
   2459           target_tensors = None
   2460         else:
   2461           # Handle target tensors if any passed.
   2462           if not isinstance(y_input, (list, tuple)):
   2463             y_input = [y_input]
   2464           target_tensors = [v for v in y_input if _is_symbolic_tensor(v)]
   2465         is_compile_called = True
   2466         self.compile(
   2467             optimizer=self.optimizer,
   2468             loss=self.loss,
   2469             metrics=self._compile_metrics,
   2470             weighted_metrics=self._compile_weighted_metrics,
   2471             loss_weights=self.loss_weights,
   2472             target_tensors=target_tensors,
   2473             run_eagerly=self.run_eagerly)
   2474 
   2475     # In graph mode, if we had just set inputs and targets as symbolic tensors
   2476     # by invoking build and compile on the model respectively, we do not have to
   2477     # feed anything to the model. Model already has input and target data as
   2478     # part of the graph.
   2479     # Note: in this case, `any` and `all` are equivalent since we disallow
   2480     # mixed symbolic/value inputs.
   2481     if (not self.run_eagerly and is_build_called and is_compile_called and
   2482         not is_dataset and any(_is_symbolic_tensor(v) for v in all_inputs)):
   2483       return [], [], []
   2484 
   2485     # What follows is input validation and standardization to list format,
   2486     # in the case where all inputs are value arrays.
   2487 
   2488     if self.run_eagerly:
   2489       # In eager mode, do not do shape validation
   2490       # since the network has no input nodes (placeholders) to be fed.
   2491       feed_input_names = self.input_names
   2492       feed_input_shapes = None
   2493     elif not self._is_graph_network:
   2494       # Case: symbolic-mode subclassed network. Do not do shape validation.
   2495       feed_input_names = self._feed_input_names
   2496       feed_input_shapes = None
   2497     else:
   2498       # Case: symbolic-mode graph network.
   2499       # In this case, we run extensive shape validation checks.
   2500       feed_input_names = self._feed_input_names
   2501       feed_input_shapes = self._feed_input_shapes
   2502 
   2503     # Standardize the inputs.
   2504     if not isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
   2505       # TODO(fchollet): run static checks with dataset output shape(s).
   2506       x = training_utils.standardize_input_data(
   2507           x,
   2508           feed_input_names,
   2509           feed_input_shapes,
   2510           check_batch_axis=False,  # Don't enforce the batch size.
   2511           exception_prefix='input')
   2512 
   2513     if y is not None:
   2514       if not self._is_graph_network:
   2515         feed_output_names = self._feed_output_names
   2516         feed_output_shapes = None
   2517         # Sample weighting not supported in this case.
   2518         # TODO(fchollet): consider supporting it.
   2519         feed_sample_weight_modes = [None for _ in self.outputs]
   2520       else:
   2521         feed_output_names = self._feed_output_names
   2522         feed_sample_weight_modes = self._feed_sample_weight_modes
   2523         feed_output_shapes = []
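                # Infer the expected target shape for each output. Sparse categorical
                # crossentropy expects one integer label per sample, so the class
                # dimension of the output shape is replaced by 1 below.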
   2524         for output_shape, loss_fn in zip(self._feed_output_shapes,
   2525                                          self._feed_loss_fns):
   2526           if ((isinstance(loss_fn, losses.LossFunctionWrapper) and
   2527                loss_fn.fn == losses.sparse_categorical_crossentropy)) or (
   2528                    isinstance(loss_fn, losses.SparseCategoricalCrossentropy)):
   2529             if K.image_data_format() == 'channels_first':
   2530               feed_output_shapes.append(
   2531                   (output_shape[0], 1) + output_shape[2:])
   2532             else:
   2533               feed_output_shapes.append(output_shape[:-1] + (1,))
   2534           elif (not isinstance(loss_fn, losses.Loss) or
   2535                 (isinstance(loss_fn, losses.LossFunctionWrapper) and
   2536                  (getattr(losses, loss_fn.fn.__name__, None) is None))):
   2537             # If the given loss is not an instance of the `Loss` class (custom
   2538             # class) or if the loss function that is wrapped is not in the
   2539             # `losses` module, then it is a user-defined loss and we make no
   2540             # assumptions about it.
   2541             feed_output_shapes.append(None)
   2542           else:
   2543             feed_output_shapes.append(output_shape)
   2544 
   2545       # Standardize the outputs.
   2546       y = training_utils.standardize_input_data(
   2547           y,
   2548           feed_output_names,
   2549           # Don't enforce target shapes to match output shapes.
   2550           # Precise checks will be run in `check_loss_and_target_compatibility`.
   2551           shapes=None,
   2552           check_batch_axis=False,  # Don't enforce the batch size.
   2553           exception_prefix='target')
   2554 
   2555       # Generate sample-wise weight values given the `sample_weight` and
   2556       # `class_weight` arguments.
   2557       sample_weights = training_utils.standardize_sample_weights(
   2558           sample_weight, feed_output_names)
   2559       class_weights = training_utils.standardize_class_weights(
   2560           class_weight, feed_output_names)
   2561       sample_weights = [
   2562           training_utils.standardize_weights(ref, sw, cw, mode)
   2563           for (ref, sw, cw, mode) in zip(y, sample_weights, class_weights,
   2564                                          feed_sample_weight_modes)
   2565       ]
   2566       # Check that all arrays have the same length.
   2567       if not self._distribution_strategy:
   2568         training_utils.check_array_lengths(x, y, sample_weights)
   2569         if self._is_graph_network and not self.run_eagerly:
   2570           # Additional checks to avoid users mistakenly using improper loss fns.
   2571           training_utils.check_loss_and_target_compatibility(
   2572               y, self._feed_loss_fns, feed_output_shapes)
   2573     else:
   2574       y = []
   2575       sample_weights = []
   2576 
   2577     if self.stateful and batch_size:
   2578       # Check that for stateful networks, number of samples is a multiple
   2579       # of the static batch size.
   2580       if x[0].shape[0] % batch_size != 0:
   2581         raise ValueError('In a stateful network, '
   2582                          'you should only pass inputs with '
   2583                          'a number of samples that can be '
   2584                          'divided by the batch size. Found: ' +
   2585                          str(x[0].shape[0]) + ' samples')
   2586 
   2587     # If dictionary inputs were provided, we return a dictionary as well.
   2588     if dict_inputs and not isinstance(x, (dataset_ops.DatasetV1,
   2589                                           dataset_ops.DatasetV2)):
   2590       x = dict(zip(feed_input_names, x))
   2591     return x, y, sample_weights
   2592 
   2593   def _unpack_validation_data(self, validation_data):
   2594     if (isinstance(validation_data, (iterator_ops.Iterator,
   2595                                      iterator_ops.EagerIterator,
   2596                                      dataset_ops.DatasetV2))):
   2597       val_x = validation_data
   2598       val_y = None
   2599       val_sample_weight = None
   2600     elif len(validation_data) == 2:
   2601       val_x, val_y = validation_data  # pylint: disable=unpacking-non-sequence
   2602       val_sample_weight = None
   2603     elif len(validation_data) == 3:
   2604       val_x, val_y, val_sample_weight = validation_data  # pylint: disable=unpacking-non-sequence
   2605     else:
   2606       raise ValueError(
   2607           'When passing a `validation_data` argument, '
   2608           'it must contain either 2 items (x_val, y_val), '
   2609           'or 3 items (x_val, y_val, val_sample_weights), '
   2610           'or alternatively it could be a dataset or a '
   2611           'dataset iterator. '
   2612           'However we received `validation_data=%s`' % validation_data)
   2613     return val_x, val_y, val_sample_weight
   2614 
   2615   # TODO(omalleyt): Consider changing to a more descriptive function name.
   2616   def _set_inputs(self, inputs, outputs=None, training=None):
   2617     """Set model's input and output specs based on the input data received.
   2618 
   2619     This is to be used for Model subclasses, which do not know at instantiation
   2620     time what their inputs look like.
   2621 
   2622     Args:
   2623       inputs: Single array, or list of arrays. The arrays could be placeholders,
   2624         Numpy arrays, data tensors, or TensorShapes.
   2625         - if placeholders: the model is built on top of these placeholders,
   2626           and we expect Numpy data to be fed for them when calling `fit`/etc.
   2627         - if Numpy data or TensorShapes: we create placeholders matching the
   2628           TensorShapes or shapes of the Numpy arrays. We expect Numpy data to be
   2629           fed for these placeholders when calling `fit`/etc.
   2630         - if data tensors: the model is built on top of these tensors.
   2631           We do not expect any Numpy data to be provided when calling `fit`/etc.
   2632       outputs: None, a data tensor, or a list of tensors. If None, the
   2633         outputs will be determined by invoking `self.call()`, otherwise the
   2634         provided value will be used.
   2635       training: Boolean or None. Only relevant in symbolic mode. Specifies
   2636         whether to build the model's graph in inference mode (False), training
   2637         mode (True), or using the Keras learning phase (None).
   2638     Raises:
   2639       ValueError: If dict inputs are passed to a Sequential Model where the
   2640         first layer isn't FeatureLayer.
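
            Example (an illustrative sketch for a subclassed model; `MyModel` is an
            assumption and is not defined in this module):

            ```python
            class MyModel(tf.keras.Model):

              def __init__(self):
                super(MyModel, self).__init__()
                self.dense = tf.keras.layers.Dense(1)

              def call(self, inputs):
                return self.dense(inputs)

            model = MyModel()
            # Builds the model's input/output specs for (None, 3)-shaped inputs;
            # normally this happens implicitly on the first `fit`/`predict` call.
            model._set_inputs(np.ones((2, 3)))
            ```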
   2641     """
   2642     inputs = self._set_input_attrs(inputs)
   2643 
   2644     if outputs is None:
   2645       kwargs = {'training': training} if self._expects_training_arg else {}
   2646       try:
   2647         outputs = self(inputs, **kwargs)
   2648       except NotImplementedError:
   2649         # This Model or a submodel is dynamic and hasn't overridden
   2650         # `compute_output_shape`.
   2651         outputs = None
   2652 
   2653     self._set_output_attrs(outputs)
   2654 
   2655   @trackable.no_automatic_dependency_tracking
   2656   def _set_input_attrs(self, inputs):
   2657     """Sets attributes related to the inputs of the Model."""
   2658     if self.inputs:
   2659       raise ValueError('Model inputs are already set.')
   2660 
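            # For a Sequential model that has not been built yet, record the inferred
            # input shape (with an unknown batch dimension) so that deferred building
            # can use it later.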
   2661     if self.__class__.__name__ == 'Sequential' and not self.built:
   2662       if tensor_util.is_tensor(inputs):
   2663         input_shape = (None,) + tuple(inputs.shape.as_list()[1:])
   2664       elif isinstance(inputs, tensor_shape.TensorShape):
   2665         input_shape = (None,) + tuple(inputs.as_list()[1:])
   2666       elif isinstance(inputs, dict):
   2667         # We assert that the first layer is a FeatureLayer.
   2668         if not training_utils.is_feature_layer(self.layers[0]):
   2669           raise ValueError('Passing a dictionary input to a Sequential Model '
   2670                            'which doesn\'t have FeatureLayer as the first layer'
   2671                            ' is an error.')
   2672         input_shape = (None,)
   2673       else:
   2674         input_shape = (None,) + tuple(inputs.shape[1:])
   2675       self._build_input_shape = input_shape
   2676 
   2677     # On-the-fly setting of symbolic model inputs (either by using the tensor
   2678     # provided, or by creating a placeholder if Numpy data was provided).
   2679     model_inputs = training_utils.ModelInputs(inputs)
   2680     inputs = model_inputs.get_symbolic_inputs()
   2681     self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True)
   2682     self.input_names = model_inputs.get_input_names()
   2683 
   2684     self._feed_inputs = []
   2685     self._feed_input_names = []
   2686     self._feed_input_shapes = []
   2687 
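            # Only placeholder inputs can be fed with user-provided data when calling
            # `fit`/`evaluate`; inputs that are already concrete tensors are left out
            # of the feed lists.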
   2688     for k, v in model_inputs.as_dict():
   2689       if K.is_placeholder(v):
   2690         self._feed_input_names.append(k)
   2691         self._feed_inputs.append(v)
   2692         self._feed_input_shapes.append(K.int_shape(v))
   2693 
   2694     return inputs
   2695 
   2696   @trackable.no_automatic_dependency_tracking
   2697   def _set_output_attrs(self, outputs):
   2698     """Sets attributes related to the outputs of the Model."""
   2699     outputs = nest.flatten(outputs)
   2700     self.outputs = outputs
   2701     self.output_names = training_utils.generic_output_names(outputs)
   2702     self.built = True
   2703 
   2704 
   2705 class DistributedCallbackModel(Model):
   2706   """Model that is used for callbacks with DistributionStrategy."""
   2707 
   2708   def __init__(self, model):
   2709     super(DistributedCallbackModel, self).__init__()
   2710     self.optimizer = model.optimizer
   2711 
   2712   def set_original_model(self, orig_model):
   2713     self._original_model = orig_model
   2714 
   2715   def save_weights(self, filepath, overwrite=True, save_format=None):
   2716     self._replicated_model.save_weights(filepath, overwrite=overwrite,
   2717                                         save_format=save_format)
   2718 
   2719   def save(self, filepath, overwrite=True, include_optimizer=True):
   2720     # save weights from the distributed model to the original model
   2721     distributed_model_weights = self.get_weights()
   2722     self._original_model.set_weights(distributed_model_weights)
   2723     # TODO(anjalisridhar): Do we need to save the original model here?
   2724     # Saving the first replicated model works as well.
   2725     self._original_model.save(filepath, overwrite=True, include_optimizer=False)
   2726 
   2727   def load_weights(self, filepath, by_name=False):
   2728     self._original_model.load_weights(filepath, by_name=by_name)
   2729     # Copy the weights from the original model to each of the replicated models.
   2730     orig_model_weights = self._original_model.get_weights()
   2731     distributed_training_utils.set_weights(
   2732         self._original_model._distribution_strategy, self,  # pylint: disable=protected-access
   2733         orig_model_weights)
   2734 
   2735   def __getattr__(self, item):
   2736     # Whitelisted attributes of the model that can be accessed by the user
   2737     # during a callback.
   2738     if item not in ['_setattr_tracking']:
   2739       logging.warning('You are accessing attribute ' + item + ' of the '
   2740                       'DistributedCallbackModel that may not have been set '
   2741                       'correctly.')
   2742 
   2743 
   2744 def _is_symbolic_tensor(x):
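          # `EagerTensor`s hold concrete values; any other tensor accepted by
          # `is_tensor` belongs to a graph and is therefore treated as symbolic.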
   2745   return tensor_util.is_tensor(x) and not isinstance(x, ops.EagerTensor)
   2746