# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """*Experimental* support for running Keras models on the TPU.
     16 
     17 To use, wrap your model with the `keras_support.tpu_model` function.
     18 
     19 Example usage:
     20 
     21 ```
     22 image = tf.keras.layers.Input(shape=(28, 28, 3), name='image')
     23 c1 = tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3))( image)
     24 flattened = tf.keras.layers.Flatten()(c1)
     25 logits = tf.keras.layers.Dense(10, activation='softmax')(flattened)
     26 model = tf.keras.Model(inputs=[image], outputs=[logits])
     27 
     28 resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=tpu_name)
     29 strategy = keras_support.TPUDistributionStrategy(resolver)
     30 model = keras_support.tpu_model(model, strategy=strategy)
     31 
     32 # Only TF optimizers are currently supported.
     33 model.compile(optimizer=tf.train.AdamOptimizer(), ...)
     34 
     35 # `images` and `labels` should be Numpy arrays.  Support for tensor input
     36 # (e.g. datasets) is planned.
     37 model.fit(images, labels)
     38 ```
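
A minimal sketch of `tf.data.Dataset` input (assuming `images` and
`labels` fit in memory; the dataset contract is described in
`TPUDatasetInfeedManager` below):

```
def make_dataset():
  ds = tf.data.Dataset.from_tensor_slices((images, labels))
  # TPU compilation requires static shapes, so drop the ragged final batch.
  return ds.repeat().batch(128, drop_remainder=True)

model.fit(make_dataset, steps_per_epoch=100)
```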
     39 """
     40 
     41 # pylint: disable=protected-access
     42 
     43 from __future__ import absolute_import
     44 from __future__ import division
     45 from __future__ import print_function
     46 
     47 import abc
     48 import collections
     49 import contextlib
     50 import re
     51 import sys
     52 import time
     53 
     54 import numpy as np
     55 import six
     56 
     57 from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver as tpu_cluster_resolver_lib
     58 from tensorflow.contrib.tpu.python.ops import tpu_ops
     59 from tensorflow.contrib.tpu.python.tpu import keras_tpu_variables
     60 from tensorflow.contrib.tpu.python.tpu import tpu
     61 from tensorflow.contrib.tpu.python.tpu import tpu_function
     62 from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
     63 from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
     64 from tensorflow.core.protobuf import config_pb2
     65 from tensorflow.core.protobuf.tpu import compilation_result_pb2 as tpu_compilation_result
     66 from tensorflow.python import tf2
     67 from tensorflow.python.client import session as tf_session
     68 from tensorflow.python.data.ops import dataset_ops
     69 from tensorflow.python.data.ops import iterator_ops
     70 from tensorflow.python.eager import context
     71 from tensorflow.python.estimator import model_fn as model_fn_lib
     72 from tensorflow.python.framework import constant_op
     73 from tensorflow.python.framework import dtypes
     74 from tensorflow.python.framework import errors
     75 from tensorflow.python.framework import ops
     76 from tensorflow.python.framework import tensor_shape
     77 from tensorflow.python.framework import tensor_spec
     78 from tensorflow.python.keras import backend as K
     79 from tensorflow.python.keras import callbacks as cbks
     80 from tensorflow.python.keras import metrics as metrics_module
     81 from tensorflow.python.keras import models
     82 from tensorflow.python.keras import optimizers as keras_optimizers
     83 from tensorflow.python.keras.engine import base_layer
     84 from tensorflow.python.keras.engine import base_layer_utils
     85 from tensorflow.python.keras.engine import training_arrays
     86 from tensorflow.python.keras.engine import training_utils
     87 from tensorflow.python.keras.layers import embeddings
     88 from tensorflow.python.keras.utils.generic_utils import make_batches
     89 from tensorflow.python.keras.utils.generic_utils import slice_arrays
     90 from tensorflow.python.ops import array_ops
     91 from tensorflow.python.ops import gen_linalg_ops
     92 from tensorflow.python.ops import math_ops
     93 from tensorflow.python.ops import random_ops
     94 from tensorflow.python.ops import variable_scope
     95 from tensorflow.python.ops import variables
     96 from tensorflow.python.platform import tf_logging as logging
     97 from tensorflow.python.util.deprecation import deprecated
     98 
     99 
# TODO(b/114775106): temporary shim to optionally initialize the TPU
# This increases the odds our session is initialized, but shouldn't be needed.
_TEST_REWRITE_OP = None


def _maybe_initialize_tpu(session):
  """Initialize the TPU if it has not already been initialized."""
  global _TEST_REWRITE_OP
  try:
    # Try to use the cached op to avoid another round of graph optimization.
    test_rewrite_op = _TEST_REWRITE_OP
    if (test_rewrite_op is None or
        test_rewrite_op[0].graph != ops.get_default_graph()):

      def test_op():
        return constant_op.constant(1) + constant_op.constant(1)

      test_rewrite_op = tpu.rewrite(test_op)
      _TEST_REWRITE_OP = test_rewrite_op

    session.run(test_rewrite_op)
  except errors.FailedPreconditionError as _:
    session.run(tpu.initialize_system())

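# A behavioral sketch of the probe above: running any rewritten
# computation, even the trivial 1 + 1 in `test_op`, raises
# FailedPreconditionError on a TPU whose system has not yet been
# initialized, in which case we fall back to an explicit
# tpu.initialize_system() call.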

@contextlib.contextmanager
def _tpu_session_context():
  """Initializes the TPU and clears cache entries for bad sessions."""
  try:
    _maybe_initialize_tpu(K.get_session())
    yield
  except (errors.FailedPreconditionError, errors.AbortedError) as e:
    K.clear_session()
    raise Exception("""
An error occurred connecting or initializing your TPU.

The session has been reset. Re-run keras_to_tpu_model to create a new session.
""" + str(e))


def setup_tpu_session(cluster_resolver):
  """Construct or return a `tf.Session` connected to the given cluster."""
  master = cluster_resolver.master()

  # Use the existing session if we're already connected to this TPU.
  # N.B. K.get_session() is a non-trivial operation, and may fail if the
  # remote session has been reset.
  try:
    default_session = K.get_session()
    if (default_session._target == master and
        getattr(default_session, '_tpu_initialized', None)):
      return
  except errors.AbortedError as _:
    # We lost the remote session and need to re-initialize.
    logging.warning('Lost remote session: creating a new session.')

  cluster_spec = cluster_resolver.cluster_spec()
  config = config_pb2.ConfigProto(isolate_session_state=True)
  if cluster_spec:
    config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

  tpu_session = tf_session.Session(target=master, config=config)
  tpu_session.run(tpu.initialize_system())
  tpu_session._tpu_initialized = True

  # N.B. We have to call `K.set_session()` AND set our session as the
  # TF default. `K.get_session()` surprisingly does not return the value
  # supplied by K.set_session otherwise.
  K.set_session(tpu_session)


try:
  from scipy.sparse import issparse  # pylint: disable=g-import-not-at-top
except ImportError:
  issparse = None


def get_tpu_system_metadata(tpu_cluster_resolver):
  """Retrieves TPU system metadata given a TPUClusterResolver."""
  master = tpu_cluster_resolver.master()

  # pylint: disable=protected-access
  cluster_spec = tpu_cluster_resolver.cluster_spec()
  cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
  tpu_system_metadata = (
      tpu_system_metadata_lib._query_tpu_system_metadata(
          master, cluster_def=cluster_def, query_topology=False))

  return tpu_system_metadata


class TPUDistributionStrategy(object):
  """The strategy for running a Keras model on a TPU."""

  def __init__(self, tpu_cluster_resolver=None, using_single_core=False):
    """Construct a TPUDistributionStrategy.

    Args:
      tpu_cluster_resolver: Any instance of `TPUClusterResolver`. If None, will
        create one with '' as master address.
      using_single_core: Bool. This is a debugging option, which might be
        removed in future once the model replication functionality is mature
        enough. If `False` (default behavior), the system automatically finds
        the best configuration, in terms of number of TPU cores, for the model
        replication, typically using all available TPU cores. If set to `True`,
        forces model replication on a single core, i.e., no replication.
    Raises:
      Exception: No TPU found on the given worker.
    """
    if tf2.enabled():
      raise RuntimeError(
          'Keras support is now deprecated in favor of TPU Strategy. '
          'Please follow the distribution strategy guide on tensorflow.org '
          'to migrate to the 2.0 supported version.')
    else:
      logging.warning(
          'Keras support is now deprecated in favor of TPU Strategy. '
          'Please follow the distribution strategy guide on tensorflow.org '
          'to migrate to the 2.0 supported version.')
    if tpu_cluster_resolver is None:
      tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('')

    metadata = get_tpu_system_metadata(tpu_cluster_resolver)
    self._tpu_metadata = metadata
    self._tpu_cluster_resolver = tpu_cluster_resolver
    self._num_cores = 1 if using_single_core else metadata.num_cores

    # Walk the device list to identify the TPU worker for enqueue/dequeue
    # operations.
    worker_re = re.compile('/job:([^/]+)')
    for device in metadata.devices:
      if 'TPU:0' in device.name:
        self._worker_name = worker_re.search(device.name).group(1)
        return
    raise Exception('No TPU found on given worker.')
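
  # Illustratively (hypothetical device name): a device named
  # '/job:tpu_worker/replica:0/task:0/device:TPU:0' matches the regex
  # above and yields the worker name 'tpu_worker'.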

  def _make_assignment_for_model(self, cpu_model):
    """Makes a `TPUAssignment` for the passed-in `cpu_model`."""
    num_cores = self._num_cores
    if num_cores > 1 and cpu_model.stateful:
      logging.warning(
          'Model replication does not currently support stateful models.  '
          'Degrading to a single core.')
      num_cores = 1

    return TPUAssignment(worker_name=self._worker_name, num_cores=num_cores)


class TPUAssignment(object):
  """Holds the TPU resource assignment for a concrete model.

  `TPUDistributionStrategy` is responsible for creating the instance of
  `TPUAssignment`, so it can dynamically adjust the `num_cores` to use based
  on the model and input batch sizes.
  """

  def __init__(self, worker_name, num_cores):
    self._worker_name = worker_name
    self._num_cores = num_cores

  @property
  def worker_name(self):
    return self._worker_name

  @property
  def num_towers(self):
    # TODO(xiejw): Support automatically assigning num_cores based on inputs.
    return self._num_cores


class TPUEmbedding(embeddings.Embedding):
  """TPU compatible embedding layer.

  The default Keras layer is not TPU compatible.  This layer is a drop-in
  replacement: it has the same behavior and will work on CPU and GPU devices.
  """

  def build(self, input_shape):
    if input_shape[0] is None:
      raise ValueError(
          'TPUEmbeddings must have a fixed input_length or input shape.')
    return super(TPUEmbedding, self).build(input_shape)

  def call(self, inputs):
    if K.dtype(inputs) != 'int32':
      inputs = math_ops.cast(inputs, 'int32')

    inputs = array_ops.one_hot(inputs, self.input_dim)
    return math_ops.tensordot(inputs, self.embeddings, 1)
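
  # Shape sketch of the one-hot trick above (with input_dim=V,
  # output_dim=D): integer `inputs` of shape [B, T] become a one-hot
  # tensor of shape [B, T, V], and the tensordot with the [V, D]
  # embedding matrix produces [B, T, D]. This replaces the gather-based
  # lookup with a dense matmul, which maps well onto the TPU.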


def _cross_replica_concat(tensor, core_id, num_cores, name):
  """Concatenate `tensor` across cores.

  Args:
    tensor: The tensor to be concatenated. Must be one of bfloat16, float32,
      or int32.
    core_id: Tensor indicating the current TPU core.
    num_cores: Python int. The total number of TPU cores in the system.
    name: The string name to print for debugging.

  Returns:
    The same concatenated Tensor on each core.
  """

  input_dtype = tensor.dtype
  if input_dtype not in [dtypes.bfloat16, dtypes.float32, dtypes.int32]:
    raise TypeError('For model replication, only bfloat16, float32 and int32 '
                    'are supported for model outputs and targets. Got {} for '
                    '{}.'.format(input_dtype, name))

  batch_size = tensor.shape[0]
  mask = math_ops.cast(
      math_ops.equal(np.arange(num_cores, dtype=np.int32), core_id),
      dtypes.float32)
  mask = array_ops.reshape(mask, [num_cores] + [1] * tensor.shape.ndims)
  result = mask * math_ops.cast(tensor, dtypes.float32)
  local_tensor_with_holes = array_ops.reshape(result,
                                              [-1] + result.shape.as_list()[2:])
  concat_tensor = tpu_ops.cross_replica_sum(local_tensor_with_holes)
  concat_tensor.set_shape((num_cores * batch_size,) + tuple(tensor.shape[1:]))

  if concat_tensor.dtype != input_dtype:
    concat_tensor = math_ops.cast(concat_tensor, input_dtype)
  return concat_tensor
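
# A value sketch of the masking trick above: with num_cores=2 and a
# per-core batch of 2, core 0 computes [x0, x1, 0, 0] and core 1
# computes [0, 0, y0, y1]; summing across replicas with
# cross_replica_sum yields [x0, x1, y0, y1] on every core, i.e. the
# concatenation of all per-core batches along the batch axis.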


class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
  """An optimizer that averages gradients across TPU shards."""

  def __init__(self, opt, name='KerasCrossShardOptimizer'):
    """Construct a new cross-shard optimizer.

    Args:
      opt: An existing `Optimizer` to encapsulate.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "KerasCrossShardOptimizer".
    """
    super(KerasCrossShardOptimizer, self).__init__()
    self._name = name
    self._opt = opt
    logging.info('KerasCrossShard: %s %s', self._opt, self._opt.weights)

  def get_updates(self, loss, params):
    self._opt.get_gradients = self.get_gradients
    return self._opt.get_updates(loss, params)

  def get_gradients(self, loss, params):
    num_shards = tpu_function.get_tpu_context().number_of_shards
    grads = super(KerasCrossShardOptimizer, self).get_gradients(loss, params)
    return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads]

  def get_weights(self):
    return self._opt.get_weights()

  def get_config(self):
    return self._opt.get_config()

  # Defer remaining operations to the underlying optimizer.
  def __getattr__(self, key):
    return getattr(self._opt, key)
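
  # Usage sketch (hypothetical): wrapping a Keras optimizer so that each
  # shard's gradients are summed via cross_replica_sum and divided by
  # the shard count, i.e. averaged, before the weight update:
  #   opt = KerasCrossShardOptimizer(keras_optimizers.SGD(lr=0.1))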


class TPUModelOp(
    collections.namedtuple('TPUModelOp', [
        'compile_op', 'execute_op', 'infeed_tensors', 'infeed_op', 'outfeed_op'
    ])):
  pass


def _valid_name(tensor_name):
  """Return a valid tensor name (strips '/', ':', etc)."""
  return re.sub('[^a-zA-Z0-9_-]+', '', tensor_name)
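
# For example, _valid_name('dense_1/kernel:0') returns 'dense_1kernel0'.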


def _replicated_optimizer(opt):
  """Wrap the optimizer `opt` with CrossShardOptimizer if applicable."""
  # Always wrap `opt` with CrossShardOptimizer, even if we are running on a
  # single core.  This ensures Keras properly tracks and initializes optimizer
  # variables.
  if isinstance(opt, keras_optimizers.TFOptimizer):
    return tpu_optimizer.CrossShardOptimizer(opt.optimizer)
  else:
    return KerasCrossShardOptimizer(opt)


def _clone_optimizer(optimizer, config=None, worker_name=None):
  """Returns a clone of `optimizer`, built from `config` or its own config."""
  if not isinstance(optimizer, keras_optimizers.Optimizer):
    # In the first call to tpu_model(model), Keras may not have wrapped the TF
    # optimizer in the TFOptimizer helper, e.g., the given model isn't compiled
    # or optimizer isn't set, and later generated tpu_model compiles with a TF
    # optimizer.
    return optimizer

  if isinstance(optimizer, keras_optimizers.TFOptimizer):
    return keras_optimizers.TFOptimizer(optimizer.optimizer)

  if config is None:
    config = optimizer.get_config()
  logging.info('Cloning %s %s', optimizer.__class__.__name__, config)
  with ops.device(
      '%s/device:CPU:0' % ('/job:%s' % worker_name if worker_name else '')):
    # Explicitly put optimizer parameter variables on the TPU worker.
    return optimizer.__class__.from_config(config)


class TPURewriteContext(object):
  """Prepare the environment for a Keras model during `tpu.rewrite`.

  This overrides the default placeholder behavior to instead refer to a preset
  input mapping.  Placeholders are unsupported in TPU compiled code, and must
  be replaced with explicit inputs or values from the infeed queue.

  Instead of explicitly threading inputs all the way through the Keras
  codebase, we override the behavior of the placeholder while compiling and
  inject the Tensors from the infeed in place of the placeholder.

  Similarly, as we compile a new sub-graph for each unique shape and execution
  mode, we need to override the behavior of an embedded `name_scope` call in
  the base Keras layer code.  This allows us to re-use the same weights across
  many compiles and share a single session/graph.
  """
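
  # Remapping sketch (hypothetical names): with `input_map` set to
  # {'image': infeed_tensor}, a Keras-internal call to
  # array_ops.placeholder(..., name='image') inside this context returns
  # `infeed_tensor` instead of creating a new placeholder node.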

  def __init__(self, input_map):
    self._input_map = input_map
    self._default_placeholder = None
    self._default_name_scope = None

  def __enter__(self):

    def _placeholder(dtype, shape=None, name=None):  # pylint: disable=unused-argument
      logging.info('Remapping placeholder for %s', name)
      if name in self._input_map:
        return self._input_map[name]
      else:
        logging.info('Default: %s', name)
        return self._default_placeholder(dtype, shape, name)

    def _name_scope(name, default_name=None, values=None):
      caller_frame = sys._getframe().f_back
      caller_obj = caller_frame.f_locals.get('self')
      if (caller_obj is not None and
          isinstance(caller_obj, base_layer.Layer) and name is not None):
        return variable_scope.variable_scope(
            name, default_name, values, reuse=variable_scope.AUTO_REUSE)

      return self._default_name_scope(name, default_name, values)

    self._default_placeholder = array_ops.placeholder
    self._default_name_scope = ops.name_scope
    self._default_make_variable = base_layer_utils.make_variable
    self._default_random_normal = random_ops.random_normal
    self._default_qr = gen_linalg_ops.qr

    array_ops.placeholder = _placeholder

    # Replace random_ops.random_normal with a dummy function because
    # `random_normal` isn't yet implemented on the TPU. These initial values
    # are overwritten by the CPU values anyway, so this is okay.
    def random_normal(shape,
                      mean=0.0,
                      stddev=1.0,
                      dtype=dtypes.float32,
                      seed=None,
                      name=None):
      del mean
      del stddev
      del seed
      return array_ops.zeros(shape, dtype=dtype, name=name)

    random_ops.random_normal = random_normal

    # Replace gen_linalg_ops.qr because QR decomposition is not yet
    # implemented on the TPU.
    # TODO(saeta): Remove qr override once we confirm the qr implementation is
    # ok.
    # pylint: disable=redefined-builtin
    def qr(input, full_matrices=False, name=None):
      """Dummy implementation of qr decomposition."""
      del full_matrices  # TODO(saeta): Properly handle the full matrix case.
      input_shape = input.shape
      if len(input_shape) < 2:
        raise ValueError('Invalid shape passed to qr: %s' % input_shape)
      p = min(input_shape[-1], input_shape[-2])
      if len(input_shape) == 2:
        q = array_ops.zeros((p, p), name=name)
        r = array_ops.zeros(input_shape, name=name)
        return (r, q)
      elif len(input_shape) == 3:
        n = input_shape[0]
        q = array_ops.zeros((n, p, p), name=name)
        r = array_ops.zeros(input_shape, name=name)
        return (r, q)
      else:
        raise ValueError('Invalid shape passed to qr: %s' % input_shape)

    gen_linalg_ops.qr = qr

    ops.name_scope = _name_scope
    base_layer_utils.make_variable = variable_scope.get_variable
    logging.info('Overriding default placeholder.')
    return

  def __exit__(self, exc_type, exc_val, exc_tb):
    array_ops.placeholder = self._default_placeholder
    ops.name_scope = self._default_name_scope
    base_layer_utils.make_variable = self._default_make_variable
    random_ops.random_normal = self._default_random_normal
    gen_linalg_ops.qr = self._default_qr


class SizedInfeed(
    collections.namedtuple('SizedInfeed',
                           ['sharded_infeed_tensors', 'infeed_ops'])):
  """Represents an instantiation of the infeed ops for a concrete input shape.

  sharded_infeed_tensors: A data structure of Tensors used to represent the
    placeholder tensors that must be fed when using feed_dicts.

  infeed_ops: the set of ops that will be run to drive infeed for a single
    step.
  """
  pass


class TPUInfeedInstance(object):
  """TPUInfeedInstance represents the logic to manage feeding in a single step.

  See the comments on the `TPUInfeedManager` for a description of how infeed
  is managed.
  """

  @abc.abstractmethod
  def make_input_specs(self, input_tensors):
    """Constructs the infeed_specs for the given Infeed instance.

    Args:
      input_tensors: The inputs to the model.

    Returns:
      A list of `tensor_spec.TensorSpec` describing the infeed for this
      instance.
    """
    pass

  def make_feed_dict(self, tpu_model_op):
    """Constructs a feed_dict for this instance, given the tpu_model_op.

    Args:
      tpu_model_op: A `TPUModelOp` representing the TPU Model for this
        instance's input spec.

    Returns:
      A dictionary to use as the feed_dict of a `session.run` call.
    """
    pass


@six.add_metaclass(abc.ABCMeta)
class TPUInfeedManager(object):
  """TPUInfeedManager manages the infeed of data to a TPU computation.

  Because there are multiple data sources (e.g. in-memory NumPy arrays,
  `tf.data.Dataset`s), we abstract the different logic behind a single
  interface: the `TPUInfeedManager`.

  (1) A `TPUFunction` is called with a set of inputs. Based on the inputs,
  `TPUFunction` retrieves the corresponding `TPUInfeedManager` (or constructs a
  new one if required).

  (2) The `TPUFunction` calls `make_infeed_instance` on the `TPUInfeedManager`
  which returns a `TPUInfeedInstance`.

  (3) The `TPUFunction` checks in the shape cache for a pre-compiled instance
  of the model based on the returned `input_specs` from `TPUInfeedInstance`.

  (4) [Optional.] If the model has not already been instantiated for the given
  input spec, the `TPUFunction` compiles the model for the input spec (using
  the `TPUInfeedManager`).

  (5) The `TPUInfeedInstance` constructs the session.run's feed_dict given the
  compiled model instance corresponding to its shape.
  """
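
  # The five steps above, as a pseudocode sketch (names hypothetical):
  #   manager = lookup_or_create_infeed_manager(inputs)             # (1)
  #   instance = manager.make_infeed_instance(inputs)               # (2)
  #   specs = instance.make_input_specs(input_tensors)              # (3)
  #   tpu_model_op = cache.get(specs) or specialize_model(specs)    # (4)
  #   session.run([tpu_model_op.infeed_op, tpu_model_op.execute_op],
  #               feed_dict=instance.make_feed_dict(tpu_model_op))  # (5)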

  @abc.abstractmethod
  def make_infeed_instance(self, inputs):
    """Given a single step's input, construct a `TPUInfeedInstance`.

    Args:
      inputs: The inputs to a given step.

    Returns:
      A subclass of `TPUInfeedInstance`.
    """
    pass

  @abc.abstractmethod
  def build_infeed_from_input_specs(self, input_specs, execution_mode):
    """For a given input specification (size, type), construct the infeed ops.

    This is called only once for a given input specification and builds the
    graph ops. It does not have a pointer to the actual infeed data.

    Args:
      input_specs: A list of `tensor_spec.TensorSpec` describing the inputs
        for this execution.
      execution_mode: The `model_fn_lib.ModeKeys` mode for this execution.

    Returns:
      A `SizedInfeed` instance.
    """
    pass


class TPUNumpyInfeedManager(TPUInfeedManager):
  """TPU Infeed manager for Numpy inputs."""

  class NumpyInfeedInstance(TPUInfeedInstance):
    """Infeed instance for Numpy inputs."""

    def __init__(self, sharded_inputs):
      self._sharded_inputs = sharded_inputs

    def make_input_specs(self, input_tensors):
      # Compute an input specification (used to generate infeed enqueue and
      # dequeue operations).  We use the shape from our input array and the
      # dtype from our model.  A user may pass in a float64 for a float32
      # input: for model compatibility we still must generate a float32 infeed.
      input_specs = []
      # We use the shape and dtype from the first shard to compute the input
      # metadata (`input_specs`); all replicas have the same type and shape.
      for tensor, ary in zip(input_tensors, self._sharded_inputs[0]):
        input_specs.append(
            tensor_spec.TensorSpec(ary.shape, tensor.dtype,
                                   _valid_name(tensor.name)))

      return input_specs

    def make_feed_dict(self, tpu_model_op):
      infeed_dict = {}
      for infeed_tensors, inputs in zip(tpu_model_op.infeed_tensors,
                                        self._sharded_inputs):
        for tensor, value in zip(infeed_tensors, inputs):
          infeed_dict[tensor] = value
      return infeed_dict

  def __init__(self, tpu_assignment):
    self._tpu_assignment = tpu_assignment

  def _split_tensors(self, inputs):
    """Split input data across shards.

    Each input is sliced along the batch axis.

    Args:
      inputs: List of Numpy arrays to run on the TPU.

    Returns:
      List of lists containing the input to feed to each TPU shard.
    """
    if self._tpu_assignment.num_towers == 1:
      return [inputs]

    batch_size = inputs[0].shape[0]
    assert batch_size % self._tpu_assignment.num_towers == 0, (
        'batch_size must be divisible by the number of TPU cores in use (%s '
        'vs %s)' % (batch_size, self._tpu_assignment.num_towers))
    shard_size = batch_size // self._tpu_assignment.num_towers
    input_list = []
    for index in range(self._tpu_assignment.num_towers):
      shard_inputs = [
          x[index * shard_size:(index + 1) * shard_size] for x in inputs
      ]
      input_list.append(shard_inputs)
    return input_list
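
  # Illustratively: with num_towers=2 and a batch of 8 rows, each input
  # array is split into shards x[0:4] and x[4:8], one per TPU core.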

  def make_infeed_instance(self, inputs):
    sharded_inputs = self._split_tensors(inputs)
    return self.NumpyInfeedInstance(sharded_inputs)

  def build_infeed_from_input_specs(self, input_specs, execution_mode):
    infeed_op = []
    shard_infeed_tensors = []

    for shard_id in range(self._tpu_assignment.num_towers):
      with ops.device(
          '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
        infeed_tensors = []
        with ops.device('/device:TPU:%d' % shard_id):
          for spec in input_specs:
            # Construct placeholders for each of the inputs.
            infeed_tensors.append(
                array_ops.placeholder(
                    dtype=spec.dtype,
                    shape=spec.shape,
                    name='infeed-enqueue-%s-%d' % (spec.name, shard_id)))
        shard_infeed_tensors.append(infeed_tensors)

        infeed_op.append(
            tpu_ops.infeed_enqueue_tuple(
                infeed_tensors, [spec.shape for spec in input_specs],
                name='infeed-enqueue-%s-%d' % (execution_mode, shard_id),
                device_ordinal=shard_id))
    return SizedInfeed(
        infeed_ops=infeed_op, sharded_infeed_tensors=shard_infeed_tensors)


class TPUDatasetInfeedManager(TPUInfeedManager):
  """Manages infeed for a `tf.data.Dataset` into a TPU computation."""

  class DatasetInfeedInstance(TPUInfeedInstance):
    """An instance of the TPU infeed."""

    def __init__(self, input_specs):
      self._input_specs = input_specs

    def make_input_specs(self, input_tensors):
      # TODO(saeta): Do error checking here!
      return self._input_specs

    def make_feed_dict(self, tpu_model_op):
      # TODO(saeta): Verify tpu_model_op is as expected!
      return {}

  # pylint: disable=redefined-outer-name
  def __init__(self, dataset, tpu_assignment, mode):
    """Constructs a TPUDatasetInfeedManager.

    Args:
      dataset: A `tf.data.Dataset` to infeed.
      tpu_assignment: The `TPUAssignment` used to configure the
        Keras TPU model.
      mode: ModeKeys enum.
    """
    self._verify_dataset_shape(dataset)

    self._dataset = dataset
    self._tpu_assignment = tpu_assignment
    dataset_output_shapes = dataset_ops.get_legacy_output_shapes(dataset)
    dummy_x_shape = dataset_output_shapes[0].as_list()
    dummy_x_shape[0] *= tpu_assignment.num_towers
    dummy_y_shape = dataset_output_shapes[1].as_list()
    dummy_y_shape[0] *= tpu_assignment.num_towers
    self._iterator = dataset_ops.make_initializable_iterator(dataset)
    K.get_session().run(self._iterator.initializer)

    self._get_next_ops = []
    ctrl_deps = []
    for i in range(tpu_assignment.num_towers):
      with ops.control_dependencies(ctrl_deps):  # Ensure deterministic order.
        # TODO(saeta): Ensure correct placement!
        get_next_op = self._iterator.get_next()
        self._get_next_ops.append(get_next_op)
        ctrl_deps.extend(get_next_op)

    # Use dummy numpy inputs for the rest of Keras' shape checking. We
    # intercept them when building the model.
    dataset_output_types = dataset_ops.get_legacy_output_types(dataset)
    self._dummy_x = np.zeros(
        dummy_x_shape, dtype=dataset_output_types[0].as_numpy_dtype)
    self._dummy_y = np.zeros(
        dummy_y_shape, dtype=dataset_output_types[1].as_numpy_dtype)

    input_specs = []
    iterator_output_shapes = dataset_ops.get_legacy_output_shapes(
        self._iterator)
    iterator_output_types = dataset_ops.get_legacy_output_types(self._iterator)
    if isinstance(iterator_output_shapes, tuple):
      assert isinstance(iterator_output_types, tuple)
      assert len(iterator_output_shapes) == len(iterator_output_types)
      for i in range(len(iterator_output_shapes)):
        spec = tensor_spec.TensorSpec(iterator_output_shapes[i],
                                      iterator_output_types[i])
        input_specs.append(spec)
    elif isinstance(iterator_output_shapes, tensor_shape.TensorShape):
      spec = tensor_spec.TensorSpec(iterator_output_shapes,
                                    iterator_output_types)
      input_specs.append(spec)

    # Pre-process the inputs and get_next_ops before caching.
    input_specs, self._get_next_ops = (
        _inject_tpu_inputs_for_dataset(
            tpu_assignment, mode, input_specs, self._get_next_ops))
    self._infeed_instance = self.DatasetInfeedInstance(input_specs)

  def _verify_dataset_shape(self, dataset):
    """Verifies a dataset is of an appropriate shape for TPUs."""
    dataset_output_shapes = dataset_ops.get_legacy_output_shapes(dataset)
    dataset_output_classes = dataset_ops.get_legacy_output_classes(dataset)
    if not isinstance(dataset, dataset_ops.DatasetV2):
      raise ValueError('The function passed as the `x` parameter did not '
                       'return a `tf.data.Dataset`.')
    if not isinstance(dataset_output_classes, tuple):
      raise ValueError('The dataset must return a tuple of tf.Tensors, '
                       'instead it returns: %s' % dataset_output_classes)
    if len(dataset_output_classes) != 2:
      raise ValueError('The dataset must return a 2-element tuple, got '
                       '%s output classes instead.' % (dataset_output_classes,))
    for i, cls in enumerate(dataset_output_classes):
      if cls != ops.Tensor:
        raise ValueError('The dataset returned a non-Tensor type (%s) at '
                         'index %d.' % (cls, i))
    for i, shape in enumerate(dataset_output_shapes):
      if not shape:
        raise ValueError('The dataset returns a scalar tensor in '
                         'tuple index %d. Did you forget to batch? '
                         '(Output shapes: %s).' % (i, dataset_output_shapes))
      for j, dim in enumerate(shape):
        if dim.value is None:
          if j == 0:
            hint = (' Hint: did you use `ds.batch(BATCH_SIZE, '
                    'drop_remainder=True)`?')
          else:
            hint = ''
          raise ValueError(
              'The Keras-TPU integration for `tf.data` '
              'currently requires static shapes. The provided '
              'dataset only has a partially defined shape. '
              '(Dimension %d of output tensor %d is not statically known '
              'for output shapes: %s.%s)' % (j, i, dataset_output_shapes, hint))

  @property
  def dummy_x(self):
    return self._dummy_x

  @property
  def dummy_y(self):
    return self._dummy_y

  def make_infeed_instance(self, inputs):
    # TODO(saeta): Verify inputs is as expected.
    return self._infeed_instance

  def build_infeed_from_input_specs(self, input_specs, execution_mode):
    shard_infeed_tensors = self._get_next_ops
    assert len(shard_infeed_tensors) == self._tpu_assignment.num_towers
    infeed_ops = []
    for shard_id in range(self._tpu_assignment.num_towers):
      with ops.device(
          '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
        infeed_ops.append(
            tpu_ops.infeed_enqueue_tuple(
                shard_infeed_tensors[shard_id],
                [spec.shape for spec in input_specs],
                name='infeed-enqueue-%s-%d' % (execution_mode, shard_id),
                device_ordinal=shard_id))
    return SizedInfeed(
        infeed_ops=infeed_ops, sharded_infeed_tensors=shard_infeed_tensors)


def _inject_tpu_inputs_for_dataset(tpu_assignment, mode,
                                   input_specs, get_next_ops):
  """Append core information to the set of dataset inputs."""
  # This is used during compilation to identify the current TPU core and
  # enable concatenation operations across cores.
  if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL]:
    return input_specs, get_next_ops

  # Dataset inputs operate on a per-core basis.
  per_core_batch_size = input_specs[0].shape.as_list()[0]

  # Insert, at head, the tensor for core_id.
  assert len(get_next_ops) == tpu_assignment.num_towers
  for i in range(tpu_assignment.num_towers):
    core_id_constant = constant_op.constant(
        np.array([i] * per_core_batch_size).astype('int32'),
        dtype=dtypes.int32,
        name='core_id_constant')
    get_next_ops[i] = [core_id_constant] + list(get_next_ops[i])

  # Insert the input spec at head also.
  input_specs = [tensor_spec.TensorSpec([per_core_batch_size], dtypes.int32)
                ] + input_specs

  return input_specs, get_next_ops


def _inject_tpu_inputs_for_infeed(tpu_assignment, mode,
                                  core_id_place_holder, input_tensors, inputs):
  """Append core information to the set of inputs."""
  # This is used during compilation to identify the current TPU core and
  # enable concatenation operations across cores.
  if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL]:
    return input_tensors, inputs

  # Put a placeholder at the head of the input spec.
  input_tensors = [core_id_place_holder] + input_tensors

  # Now fill in the core ids. For `num_cores` = 2, `batch_size` = 8, we fill
  # the core id inputs as [0, 0, 0, 0, 1, 1, 1, 1], so each core sees its core
  # id (duplicated).
  num_cores = tpu_assignment.num_towers
  per_core_batch_size = inputs[0].shape[0] // num_cores
  core_ids = np.arange(num_cores).repeat(per_core_batch_size)
  inputs = [core_ids] + inputs
  return input_tensors, inputs


def _read_tpu_coreid_from_infeed(mode, infeed_tensors):
  """Pops the core ids out of the infeed tensors."""
  if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL]:
    return None, infeed_tensors

  if len(infeed_tensors) <= 1:
    raise RuntimeError(
        'The infeed on the TPU core has only {} tensors. '
        'This is not expected. Please report a bug.\nTensors: {}'.format(
            len(infeed_tensors), infeed_tensors))

  core_id = infeed_tensors[0][0]  # Pop out the scalar version.
  rest = infeed_tensors[1:]
  return core_id, rest


class TPUFunction(object):
  """K.function compatible interface for invoking a TPU compiled function.

  Recompilation is triggered on-demand for each new set of input shapes: the
  results are cached for future execution.  We expect most computations will
  be dominated by a standard batch size, followed by a straggler batch for
  the end of training or evaluation.

  All `inputs` and `outputs` will be loaded via the infeed and outfeed queues
  instead of being injected as `feed_dict` items or fetches.
  """

  def __init__(self, model, execution_mode, tpu_assignment):
    self.model = model
    self.execution_mode = execution_mode
    self._tpu_assignment = tpu_assignment
    self._compilation_cache = {}
    self._cloned_model = None
    self._cloned_optimizer = None
    # Create a placeholder for the TPU core ID. Cache the placeholder to avoid
    # modifying the graph for every batch.
    self._core_id_place_holder = array_ops.placeholder(
        dtype=dtypes.int32, shape=[1], name='core_id')

  def _specialize_model(self, input_specs, infeed_manager):
    """Specialize `self.model` (a Keras model) for the given input shapes."""
    # Re-create our input and output layers inside our subgraph.  They will be
    # attached to the true computation when we clone our model in `tpu_fn`.
    K.set_learning_phase(self.execution_mode == model_fn_lib.ModeKeys.TRAIN)

    # functools.partial and callable objects are not supported by tpu.rewrite
    def _model_fn():
      """Compute fit/eval/predict for the TPU."""
      is_training = self.execution_mode == model_fn_lib.ModeKeys.TRAIN
      is_test = self.execution_mode == model_fn_lib.ModeKeys.EVAL
      is_predict = self.execution_mode == model_fn_lib.ModeKeys.PREDICT

      # During train/eval, we infeed our features as well as labels.
      if is_training or is_test:
        infeed_layers = self.model._input_layers + self.model._output_layers
      else:
        infeed_layers = self.model._input_layers

      # Generate our infeed operation to read features & labels.
      infeed_tensors = tpu_ops.infeed_dequeue_tuple(
          dtypes=[spec.dtype for spec in input_specs],
          shapes=[spec.shape for spec in input_specs],
          name='infeed-%s' % self.execution_mode)

      core_id, infeed_tensors = (
          _read_tpu_coreid_from_infeed(
              mode=self.execution_mode, infeed_tensors=infeed_tensors))

      assert len(infeed_tensors) == len(infeed_layers), (
          'Infeed inputs did not match model: %s vs %s' % (infeed_layers,
                                                           infeed_tensors))

      tpu_targets = []
      tpu_input_map = {}

      # Sort infeed outputs into inputs and labels for calling our Keras model.
      for tensor, layer in zip(infeed_tensors, infeed_layers):
        if layer in self.model._input_layers:
          tpu_input_map[layer.name] = tensor
        if layer in self.model._output_layers:
          tpu_targets.append(tensor)

      # Clone our CPU model, running within the TPU device context.
      #
      # We use the id of the original model as a key to avoid weight collisions
      # (if a user re-runs the same model multiple times, in e.g. Colab).
      with TPURewriteContext(tpu_input_map):
        with variable_scope.variable_scope('tpu_%s' % id(self.model)):
          with keras_tpu_variables.replicated_scope(
              self._tpu_assignment.num_towers):
            if not self._cloned_optimizer:
              self._cloned_optimizer = _clone_optimizer(
                  self.model.cpu_optimizer,
                  worker_name=self._tpu_assignment.worker_name)

            self._cloned_model = models.clone_model(self.model)

            # When running on more than one core, concatenate outputs at the
            # end of processing. During the backprop stage, the gradients are
            # calculated from the local inputs only: the gradient of the
            # cross-replica-concat is zero for any outputs other than those
            # from the local core, so the loss calculation is identical.
            num_towers = self.model._tpu_assignment.num_towers
            if num_towers > 1 and (is_training or is_test):
              new_outputs = [
                  _cross_replica_concat(
                      o, core_id, num_towers,
                      name='model output ({})'.format(o.name))
                  for o in self._cloned_model.outputs
              ]
              # Recast all low precision outputs back to float32 since we only
              # cast the inputs to bfloat16 and not the targets. This is done
              # so that we can preserve precision when calculating the loss
              # value.
              if new_outputs and new_outputs[0].dtype == dtypes.bfloat16:
                new_outputs = [
                    math_ops.cast(o, dtypes.float32) for o in new_outputs]
              self._cloned_model.outputs = new_outputs
              tpu_targets = [
                  _cross_replica_concat(
                      tensor,
                      core_id,
                      num_towers,
                      name='model target ({})'.format(tensor.name))
                  for tensor in tpu_targets
              ]

          if is_training or is_test:
            with variable_scope.variable_scope(
                'metrics', reuse=variable_scope.AUTO_REUSE):
              self._cloned_model.compile(
                  optimizer=_replicated_optimizer(self._cloned_optimizer),
                  loss=self.model.loss,
                  loss_weights=self.model.loss_weights,
                  metrics=metrics_module.clone_metrics(
                      self.model._compile_metrics),
                  weighted_metrics=metrics_module.clone_metrics(
                      self.model._compile_weighted_metrics),
                  target_tensors=tpu_targets,
              )

      # Compute our outfeed depending on the execution mode.
      if is_training:
        if not isinstance(self._cloned_optimizer, keras_optimizers.TFOptimizer):
          # For a Keras optimizer, we try to place the variable weights on the
          # TPU device. Keras creates optimizer variables (e.g. momentum values
          # for the Momentum optimizer) when _make_train_function is invoked.
          with keras_tpu_variables.replicated_variable_for_optimizer(
              self._tpu_assignment.num_towers):
            self._cloned_model._make_train_function()
        else:
          self._cloned_model._make_train_function()

        self._outfeed_spec = [
            tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
            for tensor in self._cloned_model.train_function.outputs
        ]
        return [
            self._cloned_model.train_function.updates_op,
            tpu_ops.outfeed_enqueue_tuple(
                self._cloned_model.train_function.outputs,
                name='outfeed-enqueue-train')
        ]
      elif is_test:
        self._cloned_model._make_test_function()
        self._outfeed_spec = [
            tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
            for tensor in self._cloned_model.test_function.outputs
        ]
        return [
            tpu_ops.outfeed_enqueue_tuple(
                self._cloned_model.test_function.outputs,
                name='outfeed-enqueue-test')
        ]
      elif is_predict:
        self._cloned_model._make_predict_function()
        self._outfeed_spec = [
            tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
            for tensor in self._cloned_model.predict_function.outputs
        ]
        return [
            tpu_ops.outfeed_enqueue_tuple(
                self._cloned_model.predict_function.outputs,
                name='outfeed-enqueue-predict',
            )
        ]
      else:
        assert False, 'Unexpected execution mode: %s' % self.execution_mode

    # Capture outfeed metadata computed during the rewrite.
    self._outfeed_spec = None

    # Generate our TPU operations using `tpu.split_compile_and_replicate`.
    # `compile_op` can be used to test that the TPU model compiles before
    # execution.  `execute_op` replicates `_model_fn` `num_replicas` times,
    # with each shard running on a different logical core.
    compile_op, execute_op = tpu.split_compile_and_replicate(
        _model_fn, inputs=[[] for _ in range(self._tpu_assignment.num_towers)])

    # Generate CPU side operations to enqueue features/labels and dequeue
    # outputs from the model call.
    sized_infeed = infeed_manager.build_infeed_from_input_specs(
        input_specs, self.execution_mode)
    # Build output ops.
    outfeed_op = []
    for shard_id in range(self._tpu_assignment.num_towers):
      with ops.device(
          '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
        outfeed_op.extend(
            tpu_ops.outfeed_dequeue_tuple(
                dtypes=[spec.dtype for spec in self._outfeed_spec],
                shapes=[spec.shape for spec in self._outfeed_spec],
                name='outfeed-dequeue-%s-%d' % (self.execution_mode, shard_id),
                device_ordinal=shard_id))

    return TPUModelOp(
        compile_op,
        execute_op,
        infeed_tensors=sized_infeed.sharded_infeed_tensors,
        infeed_op=sized_infeed.infeed_ops,
        outfeed_op=outfeed_op)

  def _test_model_compiles(self, tpu_model_ops):
    """Verifies that the given TPUModelOp can be compiled via XLA."""
    logging.info('Started compiling')
    start_time = time.time()

    result = K.get_session().run(tpu_model_ops.compile_op)
    proto = tpu_compilation_result.CompilationResultProto()
    proto.ParseFromString(result)
    if proto.status_error_message:
      raise RuntimeError('Compilation failed: {}'.format(
          proto.status_error_message))

    end_time = time.time()
    logging.info('Finished compiling. Time elapsed: %s secs',
                 end_time - start_time)

  def _lookup_infeed_manager(self, inputs):
    """Return an existing manager, or construct a new InfeedManager for inputs.

    _lookup_infeed_manager will return an existing InfeedManager if one has
    been previously assigned for this model and input. If not, it will
    construct a new TPUNumpyInfeedManager.

    Args:
      inputs: A NumPy input to the model.

    Returns:
      A `TPUInfeedManager` object to manage infeeds for this input.
    """
    if inputs is None:
      return None

    for x, mgr in self.model._numpy_to_infeed_manager_list:
      if inputs[0] is x:
        return mgr

    return TPUNumpyInfeedManager(self.model._tpu_assignment)

  def _tpu_model_ops_for_input_specs(self, input_specs, infeed_manager):
    """Looks up the corresponding `TPUModelOp` for a given `input_specs`.

    It instantiates a new copy of the model for each unique input shape.

    Args:
      input_specs: The specification of the inputs to train on.
      infeed_manager: The infeed manager responsible for feeding in data.

    Returns:
      A `TPUModelOp` instance that can be used to execute a step of the model.
    """
    if input_specs is None or infeed_manager is None:
      # Note: this condition is possible during the prologue or epilogue of the
      # pipelined loop.
      return None

    # XLA requires every operation in the graph to have a fixed shape.  To
    # handle varying batch sizes we recompile a new sub-graph for each
    # unique input shape.
    shape_key = tuple([tuple(spec.shape.as_list()) for spec in input_specs])
    if shape_key not in self._compilation_cache:
      logging.info(
          'New input shapes; (re-)compiling: mode=%s '
          '(# of cores %d), %s', self.execution_mode,
          self._tpu_assignment.num_towers, input_specs)
      new_tpu_model_ops = self._specialize_model(input_specs,
                                                 infeed_manager)
      self._compilation_cache[shape_key] = new_tpu_model_ops
      self._test_model_compiles(new_tpu_model_ops)

    return self._compilation_cache[shape_key]
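
  # E.g. with a steady-state batch of 128, a final partial batch of 17
  # examples produces a different shape_key and triggers one extra
  # compile for the straggler shape; both compiled sub-graphs then live
  # in `_compilation_cache` side by side.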
   1191 
   1192   def _construct_input_tensors_and_inputs(self, inputs):
   1193     """Returns input tensors and numpy array inputs corresponding to `inputs`.
   1194 
   1195     Args:
   1196       inputs: NumPy inputs.
   1197 
   1198     Returns:
   1199       A tuple of `input_tensors`, and `inputs`.
   1200     """
   1201     if inputs is None:
   1202       # Note: this condition is possible during the prologue or epilogue of the
   1203       # pipelined loop.
   1204       return None, None
   1205 
   1206     if isinstance(inputs[-1], int):
   1207       # Remove the learning_phase flag at the end. We currently hard code the
   1208       # learning_phase in TPUFunction.
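      # (The flag is appended as a trailing int by the fit loops, e.g.
      # `ins = inputs + targets + sample_weights + [1]` in
      # `_pipeline_fit_loop`.)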
   1209       inputs = inputs[:-1]
   1210 
   1211     if (self.execution_mode == model_fn_lib.ModeKeys.TRAIN or
   1212         self.execution_mode == model_fn_lib.ModeKeys.EVAL):
   1213       # Strip sample weight from inputs.
   1214       input_tensors = self.model._feed_inputs + self.model._feed_targets
   1215     else:
   1216       input_tensors = self.model._feed_inputs
   1217 
   1218     inputs = inputs[:len(input_tensors)]
   1219     input_tensors, inputs = (
   1220         _inject_tpu_inputs_for_infeed(
   1221             self._tpu_assignment, self.execution_mode,
   1222             self._core_id_place_holder, input_tensors, inputs))
   1223     return input_tensors, inputs
   1224 
   1225   def _process_outputs(self, outfeed_outputs):
   1226     """Processes the outputs of a model function execution.
   1227 
   1228     Args:
   1229       outfeed_outputs: The sharded outputs of the TPU computation.
   1230 
   1231     Returns:
   1232       The aggregated outputs of the TPU computation to be used in the rest of
   1233       the model execution.
   1234     """
   1235     # TODO(xiejw): Decide how to reduce outputs, or discard all but first.
   1236     if self.execution_mode == model_fn_lib.ModeKeys.PREDICT:
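      # For PREDICT, the outfeed arrives replica-major: with 2 towers and 2
      # model outputs it is laid out as [t0_out0, t0_out1, t1_out0, t1_out1],
      # and is regrouped below into per-output groups that are concatenated
      # along the batch dimension.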
   1237       outputs = [[] for _ in range(len(self._outfeed_spec))]
   1238       outputs_per_replica = len(self._outfeed_spec)
   1239 
   1240       for i in range(self._tpu_assignment.num_towers):
   1241         output_group = outfeed_outputs[i * outputs_per_replica:(i + 1) *
   1242                                        outputs_per_replica]
   1243         for j in range(outputs_per_replica):
   1244           outputs[j].append(output_group[j])
   1245 
   1246       return [np.concatenate(group) for group in outputs]
   1247     else:
   1248       return outfeed_outputs[:len(outfeed_outputs) //
   1249                              self._tpu_assignment.num_towers]
   1250 
   1251   def __call__(self, inputs):
   1252     """__call__ executes the function on the computational hardware.
   1253 
    It handles running the infeed and preprocessing, in addition to executing
    the model on the TPU hardware.
   1256 
   1257     Note: `__call__` has a sibling method `pipeline_run` which performs the same
   1258     operations, but with software pipelining.
   1259 
   1260     Args:
      inputs: The inputs to use for training.
   1262 
   1263     Returns:
      The output of the computation for the mode in which it is executed.
   1265 
   1266     Raises:
   1267       RuntimeError: If there is an inappropriate use of the function.
   1268     """
   1269     assert isinstance(inputs, list)
   1270 
   1271     infeed_manager = self._lookup_infeed_manager(inputs)
   1272     input_tensors, inputs = self._construct_input_tensors_and_inputs(inputs)
   1273     infeed_instance = infeed_manager.make_infeed_instance(inputs)
    del inputs  # Avoid accidental usage.
   1275     input_specs = infeed_instance.make_input_specs(input_tensors)
   1276     tpu_model_ops = self._tpu_model_ops_for_input_specs(input_specs,
   1277                                                         infeed_manager)
   1278     infeed_dict = infeed_instance.make_feed_dict(tpu_model_ops)
   1279 
   1280     # Initialize our TPU weights on the first compile.
   1281     self.model._initialize_weights(self._cloned_model)
   1282 
   1283     _, _, outfeed_outputs = K.get_session().run([
   1284         tpu_model_ops.infeed_op, tpu_model_ops.execute_op,
   1285         tpu_model_ops.outfeed_op
   1286     ], infeed_dict)
   1287     return self._process_outputs(outfeed_outputs)
   1288 
   1289   def pipeline_run(self, cur_step_inputs, next_step_inputs):
   1290     """pipeline_run executes the function on the computational hardware.
   1291 
    pipeline_run performs the same computation as __call__; however, it runs
    the infeed for the next step concurrently with the on-device execution of
    the current step (software pipelining).
   1294 
   1295     Note: it is the responsibility of the caller to call `pipeline_run` in the
   1296     following sequence:
   1297       - Once with `cur_step_inputs=None` and `next_step_inputs=list(...)`
   1298       - `n` times with `cur_step_inputs` and `next_step_inputs` as `list`s
   1299       - Once with `cur_step_inputs=list(...)` and `next_step_inputs=None`
   1300     Additionally, it is the responsibility of the caller to pass
   1301     `next_step_inputs` as `cur_step_inputs` on the next invocation of
   1302     `pipeline_run`.
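
    For illustration, a caller driving three batches (here `f` is this
    `TPUFunction` and `batch_0`..`batch_2` are stand-in input lists) would
    issue:

    ```
    f.pipeline_run(cur_step_inputs=None, next_step_inputs=batch_0)
    f.pipeline_run(cur_step_inputs=batch_0, next_step_inputs=batch_1)
    f.pipeline_run(cur_step_inputs=batch_1, next_step_inputs=batch_2)
    f.pipeline_run(cur_step_inputs=batch_2, next_step_inputs=None)
    ```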
   1303 
   1304     Args:
   1305       cur_step_inputs: The current step's inputs.
   1306       next_step_inputs: The next step's inputs.
   1307 
   1308     Returns:
      The output of the computation for the mode in which it is executed.
   1310 
   1311     Raises:
   1312       RuntimeError: If there is an inappropriate use of the function.
   1313     """
   1314     # Software pipelined case.
   1315     next_step_infeed_manager = self._lookup_infeed_manager(next_step_inputs)
   1316     cur_step_infeed_manager = self._lookup_infeed_manager(cur_step_inputs)
   1317 
   1318     if (next_step_infeed_manager is not None and
   1319         cur_step_infeed_manager is not None):
   1320       assert type(next_step_infeed_manager) is type(cur_step_infeed_manager)
   1321 
   1322     next_input_tensors, next_step_inputs = (
   1323         self._construct_input_tensors_and_inputs(next_step_inputs))
   1324     cur_input_tensors, cur_step_inputs = (
   1325         self._construct_input_tensors_and_inputs(cur_step_inputs))
   1326 
   1327     cur_infeed_instance = None
   1328     if cur_step_infeed_manager:
   1329       cur_infeed_instance = cur_step_infeed_manager.make_infeed_instance(
   1330           cur_step_inputs)
   1331     next_infeed_instance = None
   1332     if next_step_infeed_manager:
   1333       next_infeed_instance = next_step_infeed_manager.make_infeed_instance(
   1334           next_step_inputs)
   1335 
   1336     del cur_step_inputs  # Avoid accidental re-use.
   1337     del next_step_inputs  # Avoid accidental re-use.
   1338 
   1339     cur_tpu_model_ops = None
   1340     next_tpu_model_ops = None
   1341     infeed_dict = None
   1342 
   1343     if cur_infeed_instance and cur_input_tensors and cur_step_infeed_manager:
   1344       cur_input_specs = cur_infeed_instance.make_input_specs(cur_input_tensors)
   1345       cur_tpu_model_ops = self._tpu_model_ops_for_input_specs(
   1346           cur_input_specs, cur_step_infeed_manager)
   1347 
   1348     if (next_infeed_instance and next_input_tensors and
   1349         next_step_infeed_manager):
   1350       next_input_specs = next_infeed_instance.make_input_specs(
   1351           next_input_tensors)
   1352       next_tpu_model_ops = self._tpu_model_ops_for_input_specs(
   1353           next_input_specs, next_step_infeed_manager)
   1354       infeed_dict = next_infeed_instance.make_feed_dict(next_tpu_model_ops)
   1355 
   1356     # Initialize our TPU weights on the first compile.
   1357     self.model._initialize_weights(self._cloned_model)
   1358 
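    # Steady state: enqueue the infeed for the next step while executing and
    # dequeuing the outfeed of the current step, overlapping host I/O with
    # device compute.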
   1359     if next_tpu_model_ops and cur_tpu_model_ops:
   1360       _, _, outfeed_outputs = K.get_session().run([
   1361           next_tpu_model_ops.infeed_op, cur_tpu_model_ops.execute_op,
   1362           cur_tpu_model_ops.outfeed_op
   1363       ], infeed_dict)
   1364       return self._process_outputs(outfeed_outputs)
   1365 
   1366     if cur_tpu_model_ops:
   1367       _, outfeed_outputs = K.get_session().run(
   1368           [cur_tpu_model_ops.execute_op, cur_tpu_model_ops.outfeed_op])
   1369       return self._process_outputs(outfeed_outputs)
   1370 
   1371     if next_tpu_model_ops:
   1372       K.get_session().run(next_tpu_model_ops.infeed_op, infeed_dict)
   1373       return None
   1374     raise RuntimeError('Internal error: both current & next tpu_model_ops '
   1375                        'were None')
   1376 
   1377 
   1378 class KerasTPUModel(models.Model):
   1379   """TPU compatible Keras model wrapper."""
   1380 
   1381   def __init__(self, cpu_model, strategy):
   1382     super(models.Model, self).__init__(  # pylint: disable=bad-super-call
   1383         inputs=cpu_model.inputs,
   1384         outputs=cpu_model.outputs,
   1385         name=cpu_model.name,
   1386     )
    if tf2.enabled():
      raise RuntimeError(
          'Keras support is now deprecated in favor of TPU Strategy. '
          'Please follow the distribution strategy guide on tensorflow.org '
          'to migrate to the 2.0 supported version.')
    else:
      logging.warning(
          'Keras support is now deprecated in favor of TPU Strategy. '
          'Please follow the distribution strategy guide on tensorflow.org '
          'to migrate to the 2.0 supported version.')
   1397     # Create a mapping from numpy arrays to infeed managers.
   1398     # Note: uses a list of tuples instead of a map because numpy arrays are
   1399     # not hashable.
   1400     self._numpy_to_infeed_manager_list = []
   1401 
   1402     # Add distribution specific arguments since we don't call the Model init.
   1403     self._distribution_strategy = None
   1404     self._compile_distribution = None
   1405 
   1406     self.predict_function = None
   1407     self.test_function = None
   1408     self.train_function = None
   1409     self._stateful_metric_functions = []
   1410 
   1411     cluster_resolver = strategy._tpu_cluster_resolver
   1412     self._tpu_name_or_address = cluster_resolver.get_master()
   1413     self._cpu_model = cpu_model
   1414     self._tpu_assignment = strategy._make_assignment_for_model(cpu_model)
   1415     self._tpu_model = None
   1416     self._tpu_weights_initialized = False
   1417 
   1418     # If the input CPU model has already been compiled, compile our TPU model
   1419     # immediately.
   1420     if self._cpu_model.optimizer:
   1421       self.compile(
   1422           self._cpu_model.optimizer,
   1423           self._cpu_model.loss,
   1424           self._cpu_model._compile_metrics,
   1425           self._cpu_model.loss_weights,
   1426           self._cpu_model.sample_weight_mode,
   1427           self._cpu_model._compile_weighted_metrics,
   1428           self._cpu_model.target_tensors,
   1429       )
   1430 
    # This flag must be disabled upon model mutation, such as changing the
    # model layers or recompiling the model to use a different optimizer. New
    # function definitions are generated whenever this flag is disabled,
    # ensuring that internal graph functions always use the current model
    # structure.
   1435     #
   1436     # Requires declaration here because this constructor skips the
   1437     # Model constructor.
   1438     self._built_graph_functions = False
   1439 
   1440   def get_config(self):
   1441     return {
   1442         'cpu_model': self._cpu_model,
   1443         'tpu_name_or_address': self._tpu_name_or_address,
   1444         'tpu_assignment': self._tpu_assignment,
   1445     }
   1446 
   1447   def compile(self,
   1448               optimizer,
   1449               loss=None,
   1450               metrics=None,
   1451               loss_weights=None,
   1452               sample_weight_mode=None,
   1453               weighted_metrics=None,
   1454               target_tensors=None,
   1455               **kwargs):
    if sample_weight_mode:
      raise ValueError(
          'sample_weight_mode is not supported for TPU execution.')
    if weighted_metrics:
      raise ValueError(
          'weighted_metrics are not supported for TPU execution.')
    if target_tensors:
      raise ValueError(
          'target_tensors is not supported for TPU execution.')
   1462 
   1463     self._cpu_model.compile(
   1464         _clone_optimizer(optimizer), loss,
   1465         metrics_module.clone_metrics(metrics), loss_weights, sample_weight_mode,
   1466         metrics_module.clone_metrics(weighted_metrics), target_tensors,
   1467         **kwargs)
   1468 
   1469     super(KerasTPUModel, self).compile(optimizer, loss, metrics, loss_weights,
   1470                                        sample_weight_mode, weighted_metrics,
   1471                                        target_tensors, **kwargs)
   1472 
   1473   def fit(self,
   1474           x=None,
   1475           y=None,
   1476           batch_size=None,
   1477           epochs=1,
   1478           verbose=1,
   1479           callbacks=None,
   1480           validation_split=0.,
   1481           validation_data=None,
   1482           shuffle=True,
   1483           class_weight=None,
   1484           sample_weight=None,
   1485           initial_epoch=0,
   1486           steps_per_epoch=None,
   1487           validation_steps=None,
   1488           **kwargs):
   1489     if context.executing_eagerly():
   1490       raise EnvironmentError('KerasTPUModel currently does not support eager '
   1491                              'mode.')
   1492 
   1493     with _tpu_session_context():
   1494       assert not self._numpy_to_infeed_manager_list  # Ensure empty.
   1495 
   1496       infeed_managers = []  # Managers to clean up at the end of the fit call.
   1497       if isinstance(x, dataset_ops.DatasetV2):
   1498         # TODO(b/111413240): Support taking a tf.data.Dataset directly.
   1499         raise ValueError(
   1500             'Taking a Dataset directly is not yet supported. Please '
   1501             'wrap your dataset construction code in a function and '
   1502             'pass that to fit instead. For examples, see: '
   1503             'https://github.com/tensorflow/tpu/tree/master/models/experimental'
   1504             '/keras')
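      # For example (sketch; `make_dataset` is a user-supplied function
      # returning a tf.data.Dataset):
      #   model.fit(make_dataset, epochs=2, steps_per_epoch=100)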
   1505       if callable(x):
   1506         with ops.device(
   1507             '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
   1508           dataset = x()
   1509           if steps_per_epoch is None:
   1510             raise ValueError('When using tf.data as input to a model, you '
   1511                              'should specify the steps_per_epoch argument.')
   1512           if y is not None:
   1513             raise ValueError('When using tf.data as input to a model, y must '
   1514                              'be None')
   1515           infeed_manager = TPUDatasetInfeedManager(
   1516               dataset, self._tpu_assignment, model_fn_lib.ModeKeys.TRAIN)
   1517           # Use dummy numpy inputs for the rest of Keras' shape checking. We
   1518           # intercept them when building the model.
   1519           x = infeed_manager.dummy_x
   1520           y = infeed_manager.dummy_y
   1521           infeed_managers.append((x, infeed_manager))
   1522 
   1523       if isinstance(validation_data, dataset_ops.DatasetV2):
   1524         # TODO(b/111413240): Support taking a tf.data.Dataset directly.
   1525         raise ValueError(
   1526             'Taking a Dataset directly is not yet supported. Please '
   1527             'wrap your dataset construction code in a function and '
   1528             'pass that to fit instead. For examples, see: '
   1529             'https://github.com/tensorflow/tpu/tree/master/models/experimental'
   1530             '/keras')
   1531       if callable(validation_data):
   1532         dataset = validation_data()
   1533         if validation_steps is None:
   1534           raise ValueError('When using tf.data as validation for a model, you '
   1535                            'should specify the validation_steps argument.')
   1536         infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
   1537                                                  model_fn_lib.ModeKeys.EVAL)
   1538         # Use dummy numpy inputs for the rest of Keras' shape checking. We
   1539         # intercept them when building the model.
   1540         val_x = infeed_manager.dummy_x
   1541         val_y = infeed_manager.dummy_y
   1542         infeed_managers.append((val_x, infeed_manager))
   1543         validation_data = (val_x, val_y)
   1544 
   1545       self._numpy_to_infeed_manager_list = infeed_managers
   1546       try:
        pipeline = kwargs.pop('_pipeline', True)
   1550         if not pipeline:
   1551           logging.info('Running non-pipelined training loop (`_pipeline=%s`).',
   1552                        pipeline)
   1553           return super(KerasTPUModel, self).fit(
   1554               x, y, batch_size, epochs, verbose, callbacks, validation_split,
   1555               validation_data, shuffle, class_weight, sample_weight,
   1556               initial_epoch, steps_per_epoch, validation_steps, **kwargs)
   1557         return self._pipeline_fit(x, y, batch_size, epochs, verbose, callbacks,
   1558                                   validation_split, validation_data, shuffle,
   1559                                   class_weight, sample_weight, initial_epoch,
   1560                                   steps_per_epoch, validation_steps, **kwargs)
   1561       finally:
   1562         self._numpy_to_infeed_manager_list = []
   1563 
   1564   def evaluate(self,
   1565                x=None,
   1566                y=None,
   1567                batch_size=None,
   1568                verbose=1,
   1569                sample_weight=None,
   1570                steps=None):
   1571     original_numpy_to_infeed_manager_list = []
   1572     if self._numpy_to_infeed_manager_list:
      # The evaluate call may be executed as a callback during training. In
      # this case, _numpy_to_infeed_manager_list is not empty, so save it for
      # restoration at the end of the evaluate call.
   1576       original_numpy_to_infeed_manager_list = self._numpy_to_infeed_manager_list
   1577       self._numpy_to_infeed_manager_list = []
   1578 
   1579     with _tpu_session_context():
   1580       # Managers to clean up at the end of the evaluate call.
   1581       infeed_managers = []
   1582       if isinstance(x, dataset_ops.DatasetV2):
   1583         # TODO(b/111413240): Support taking a tf.data.Dataset directly.
   1584         raise ValueError(
   1585             'Taking a Dataset directly is not yet supported. Please '
   1586             'wrap your dataset construction code in a function and '
            'pass that to evaluate instead. For examples, see: '
   1588             'https://github.com/tensorflow/tpu/tree/master/models/experimental'
   1589             '/keras')
   1590       if callable(x):
   1591         dataset = x()
   1592         if steps is None:
   1593           raise ValueError('When using tf.data as input to a model, you '
   1594                            'should specify the steps argument.')
   1595         if y is not None:
   1596           raise ValueError('When using tf.data as input to a model, y must be '
   1597                            'None')
   1598         infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
   1599                                                  model_fn_lib.ModeKeys.EVAL)
   1600         # Use dummy numpy inputs for the rest of Keras' shape checking. We
   1601         # intercept them when building the model.
   1602         x = infeed_manager.dummy_x
   1603         y = infeed_manager.dummy_y
   1604         infeed_managers.append((x, infeed_manager))
   1605 
   1606       self._numpy_to_infeed_manager_list = infeed_managers
   1607       try:
   1608         return super(KerasTPUModel, self).evaluate(x, y, batch_size, verbose,
   1609                                                    sample_weight, steps)
   1610       finally:
   1611         self._numpy_to_infeed_manager_list = (
   1612             original_numpy_to_infeed_manager_list)
   1613 
   1614   def _pipeline_fit(self, x, y, batch_size, epochs, verbose, callbacks,
   1615                     validation_split, validation_data, shuffle, class_weight,
   1616                     sample_weight, initial_epoch, steps_per_epoch,
   1617                     validation_steps, **kwargs):
   1618     # Similar to super.fit(...), but modified to support software pipelining.
   1619 
   1620     # Backwards compatibility
   1621     if batch_size is None and steps_per_epoch is None:
   1622       batch_size = 32
   1623     # Legacy support
   1624     if 'nb_epoch' in kwargs:
   1625       logging.warning('The `nb_epoch` argument in `fit` has been renamed '
   1626                       '`epochs`.')
   1627       epochs = kwargs.pop('nb_epoch')
   1628     if kwargs:
   1629       raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
   1630 
   1631     # Validate and standardize user data
   1632     x, y, sample_weights = self._standardize_user_data(
   1633         x,
   1634         y,
   1635         sample_weight=sample_weight,
   1636         class_weight=class_weight,
   1637         batch_size=batch_size,
   1638         check_steps=True,
   1639         steps_name='steps_per_epoch',
   1640         steps=steps_per_epoch,
   1641         validation_split=validation_split)
   1642 
   1643     # Prepare validation data
   1644     val_x, val_y, val_sample_weights = self._prepare_validation_data(
   1645         validation_data, validation_split, validation_steps, x, y,
   1646         sample_weights, batch_size)
   1647     return self._pipeline_fit_loop(
   1648         x,
   1649         y,
   1650         sample_weights=sample_weights,
   1651         batch_size=batch_size,
   1652         epochs=epochs,
   1653         verbose=verbose,
   1654         callbacks=callbacks,
   1655         val_inputs=val_x,
   1656         val_targets=val_y,
   1657         val_sample_weights=val_sample_weights,
   1658         shuffle=shuffle,
   1659         initial_epoch=initial_epoch,
   1660         steps_per_epoch=steps_per_epoch,
   1661         validation_steps=validation_steps)
   1662 
   1663   def _pipeline_fit_loop(self,
   1664                          inputs,
   1665                          targets,
   1666                          sample_weights,
   1667                          batch_size,
   1668                          epochs,
   1669                          verbose,
   1670                          callbacks,
   1671                          val_inputs,
   1672                          val_targets,
   1673                          val_sample_weights,
   1674                          shuffle,
   1675                          initial_epoch,
   1676                          steps_per_epoch,
   1677                          validation_steps):
   1678     self._make_train_function()
   1679     sample_weights = sample_weights or []
   1680     val_sample_weights = val_sample_weights or []
   1681     if not isinstance(K.learning_phase(), int):
   1682       ins = inputs + targets + sample_weights + [1]
   1683     else:
   1684       ins = inputs + targets + sample_weights
   1685 
   1686     do_validation = False
   1687     if val_inputs:
   1688       do_validation = True
   1689       if (steps_per_epoch is None and verbose and inputs and
   1690           hasattr(inputs[0], 'shape') and hasattr(val_inputs[0], 'shape')):
   1691         print('Train on %d samples, validate on %d samples' %
   1692               (inputs[0].shape[0], val_inputs[0].shape[0]))
   1693 
   1694     if validation_steps:
   1695       do_validation = True
   1696       if steps_per_epoch is None:
   1697         raise ValueError('Can only use `validation_steps` when doing step-wise '
   1698                          'training, i.e. `steps_per_epoch` must be set.')
   1699 
   1700     num_training_samples = training_utils.check_num_samples(
   1701         ins, batch_size, steps_per_epoch, 'steps_per_epoch')
   1702     count_mode = 'steps' if steps_per_epoch else 'samples'
   1703     callbacks = cbks.configure_callbacks(
   1704         callbacks,
   1705         self,
   1706         do_validation=do_validation,
   1707         batch_size=batch_size,
   1708         epochs=epochs,
   1709         steps_per_epoch=steps_per_epoch,
   1710         samples=num_training_samples,
   1711         verbose=verbose,
   1712         count_mode=count_mode)
   1713 
   1714     if num_training_samples is not None:
   1715       index_array = np.arange(num_training_samples)
   1716 
   1717     # To prevent a slowdown, we find beforehand the arrays that need conversion.
   1718     feed = self._feed_inputs + self._feed_targets + self._feed_sample_weights
   1719     indices_for_conversion_to_dense = []
   1720     for i in range(len(feed)):
   1721       if issparse is not None and issparse(ins[i]) and not K.is_sparse(feed[i]):
   1722         indices_for_conversion_to_dense.append(i)
   1723 
   1724     callbacks.on_train_begin()
   1725     for epoch in range(initial_epoch, epochs):
   1726       # Reset stateful metrics
   1727       for m in self.metrics:
   1728         m.reset_states()
   1729       # Update callbacks
   1730       callbacks.on_epoch_begin(epoch)
   1731       epoch_logs = {}
   1732       if steps_per_epoch is not None:
   1733         # Step-wise fit loop.
   1734         self._pipeline_fit_loop_step_wise(
   1735             ins=ins,
   1736             callbacks=callbacks,
   1737             steps_per_epoch=steps_per_epoch,
   1738             epochs=epochs,
   1739             do_validation=do_validation,
   1740             val_inputs=val_inputs,
   1741             val_targets=val_targets,
   1742             val_sample_weights=val_sample_weights,
   1743             validation_steps=validation_steps,
   1744             epoch_logs=epoch_logs)
   1745       else:
   1746         # Sample-wise fit loop.
   1747         self._pipeline_fit_loop_sample_wise(
   1748             ins=ins,
   1749             callbacks=callbacks,
   1750             index_array=index_array,
   1751             shuffle=shuffle,
   1752             batch_size=batch_size,
   1753             num_training_samples=num_training_samples,
   1754             indices_for_conversion_to_dense=indices_for_conversion_to_dense,
   1755             do_validation=do_validation,
   1756             val_inputs=val_inputs,
   1757             val_targets=val_targets,
   1758             val_sample_weights=val_sample_weights,
   1759             validation_steps=validation_steps,
   1760             epoch_logs=epoch_logs)
   1761 
   1762       callbacks.on_epoch_end(epoch, epoch_logs)
   1763       if callbacks.model.stop_training:
   1764         break
   1765     callbacks.on_train_end()
   1766     return self.history
   1767 
   1768   def _pipeline_fit_loop_sample_wise(self,
   1769                                      ins,
   1770                                      callbacks,
   1771                                      index_array,
   1772                                      shuffle,
   1773                                      batch_size,
   1774                                      num_training_samples,
   1775                                      indices_for_conversion_to_dense,
   1776                                      do_validation,
   1777                                      val_inputs,
   1778                                      val_targets,
   1779                                      val_sample_weights,
   1780                                      validation_steps,
   1781                                      epoch_logs):
   1782     f = self.train_function
   1783     if shuffle == 'batch':
   1784       index_array = training_utils.batch_shuffle(index_array, batch_size)
   1785     elif shuffle:
   1786       np.random.shuffle(index_array)
   1787     batches = make_batches(num_training_samples, batch_size)
   1788 
   1789     ins_last_batch = None
   1790     last_batch_logs = None
   1791     batch_index = 0
   1792 
   1793     for batch_index, (batch_start, batch_end) in enumerate(batches):
   1794       batch_ids = index_array[batch_start:batch_end]
   1795       try:
   1796         if isinstance(ins[-1], int):
   1797           # Do not slice the training phase flag.
   1798           ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
   1799         else:
   1800           ins_batch = slice_arrays(ins, batch_ids)
   1801       except TypeError:
   1802         raise TypeError('TypeError while preparing batch. If using HDF5 '
   1803                         'input data, pass shuffle="batch".')
   1804 
   1805       # Pipeline batch logs
   1806       next_batch_logs = {}
   1807       next_batch_logs['batch'] = batch_index
   1808       next_batch_logs['size'] = len(batch_ids)
   1809       if batch_index > 0:
        # Callbacks operate one step behind in the software pipeline.
   1811         callbacks.on_batch_begin(batch_index - 1, last_batch_logs)
   1812       for i in indices_for_conversion_to_dense:
   1813         ins_batch[i] = ins_batch[i].toarray()
   1814 
   1815       outs = f.pipeline_run(
   1816           cur_step_inputs=ins_last_batch, next_step_inputs=ins_batch)
   1817       ins_last_batch = ins_batch
   1818 
   1819       if batch_index == 0:
   1820         assert outs is None
   1821       else:
   1822         if not isinstance(outs, list):
   1823           outs = [outs]
   1824         for l, o in zip(self.metrics_names, outs):
   1825           last_batch_logs[l] = o  # pylint: disable=unsupported-assignment-operation
   1826         callbacks.on_batch_end(batch_index - 1, last_batch_logs)
   1827         if callbacks.model.stop_training:
   1828           return
   1829       last_batch_logs = next_batch_logs
   1830 
   1831     # Final batch
   1832     callbacks.on_batch_begin(batch_index, last_batch_logs)
   1833     outs = f.pipeline_run(cur_step_inputs=ins_last_batch, next_step_inputs=None)
   1834     if not isinstance(outs, list):
   1835       outs = [outs]
   1836     for l, o in zip(self.metrics_names, outs):
   1837       last_batch_logs[l] = o
   1838     callbacks.on_batch_end(batch_index, last_batch_logs)
   1839     if callbacks.model.stop_training:
   1840       return
   1841 
   1842     if do_validation:
   1843       val_outs = training_arrays.test_loop(
   1844           self,
   1845           val_inputs,
   1846           val_targets,
   1847           sample_weights=val_sample_weights,
   1848           batch_size=batch_size,
   1849           steps=validation_steps,
   1850           verbose=0)
   1851       if not isinstance(val_outs, list):
   1852         val_outs = [val_outs]
   1853       # Same labels assumed.
   1854       for l, o in zip(self.metrics_names, val_outs):
   1855         epoch_logs['val_' + l] = o
   1856 
   1857   def _pipeline_fit_loop_step_wise(self,
   1858                                    ins,
   1859                                    callbacks,
   1860                                    steps_per_epoch,
   1861                                    epochs,
   1862                                    do_validation,
   1863                                    val_inputs,
   1864                                    val_targets,
   1865                                    val_sample_weights,
   1866                                    validation_steps,
   1867                                    epoch_logs):
   1868     f = self.train_function
   1869 
   1870     # Loop prologue
   1871     try:
   1872       outs = f.pipeline_run(cur_step_inputs=None, next_step_inputs=ins)
   1873       assert outs is None  # Function shouldn't return anything!
   1874     except errors.OutOfRangeError:
   1875       logging.warning('Your dataset iterator ran out of data on the first step '
   1876                       'of the epoch, preventing further training. Check to '
   1877                       'make sure your paths are correct and you have '
                      'permissions to read the files. Skipping validation.')
   1879 
   1880     for step_index in range(steps_per_epoch):
   1881       batch_logs = {'batch': step_index, 'size': 1}
   1882       callbacks.on_batch_begin(step_index, batch_logs)
   1883       try:
   1884         if step_index < steps_per_epoch - 1:
   1885           next_step_inputs = ins
   1886         else:
   1887           next_step_inputs = None
   1888         outs = f.pipeline_run(
   1889             cur_step_inputs=ins, next_step_inputs=next_step_inputs)
   1890       except errors.OutOfRangeError:
        logging.warning('Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your '
                        'dataset can generate at least `steps_per_epoch * '
                        'epochs` batches (in this case, %d batches). You '
                        'may need to use the repeat() function when '
                        'building your dataset.' % (steps_per_epoch * epochs))
   1897         break
   1898 
   1899       if not isinstance(outs, list):
   1900         outs = [outs]
   1901       for l, o in zip(self.metrics_names, outs):
   1902         batch_logs[l] = o
   1903 
   1904       callbacks.on_batch_end(step_index, batch_logs)
   1905       if callbacks.model.stop_training:
   1906         break
   1907 
   1908     if do_validation:
   1909       val_outs = training_arrays.test_loop(
   1910           self,
   1911           val_inputs,
   1912           val_targets,
   1913           sample_weights=val_sample_weights,
   1914           steps=validation_steps,
   1915           verbose=0)
   1916       if not isinstance(val_outs, list):
   1917         val_outs = [val_outs]
   1918       # Same labels assumed.
   1919       for l, o in zip(self.metrics_names, val_outs):
   1920         epoch_logs['val_' + l] = o
   1921 
   1922   def _prepare_validation_data(self, validation_data, validation_split,
   1923                                validation_steps, x, y, sample_weights,
   1924                                batch_size):
   1925     """Prepares the validation dataset.
   1926 
   1927     Args:
   1928       validation_data: The validation data (if provided)
   1929       validation_split: The validation split (if provided)
   1930       validation_steps: The validation steps (if provided)
   1931       x: The main training data x (if provided)
   1932       y: The main training data y (if provided)
   1933       sample_weights: The sample weights (if provided)
   1934       batch_size: The training batch size (if provided)
   1935 
   1936     Returns:
   1937       A 3-tuple of (val_x, val_y, val_sample_weights).
   1938 
   1939     Raises:
   1940       ValueError: If the provided arguments are not compatible with
   1941         `KerasTPUModel`.
   1942     """
    # Note: this is similar to a section of
    # $tf/python/keras/engine/training.py. It differs in that tf.data objects
    # are not allowed to be passed directly. Additionally, it handles
    # validating shapes & types appropriately for use on TPUs.
   1947     if validation_data:
      if isinstance(validation_data, (iterator_ops.Iterator,
                                      iterator_ops.EagerIterator,
                                      dataset_ops.DatasetV2)):
   1951         raise ValueError('KerasTPUModel cannot handle a Dataset or Iterator '
   1952                          'for validation_data. Please instead pass a function '
   1953                          'that returns a `tf.data.Dataset`.')
   1954       if len(validation_data) == 2:
   1955         val_x, val_y = validation_data  # pylint: disable=unpacking-non-sequence
   1956         val_sample_weight = None
   1957       elif len(validation_data) == 3:
   1958         val_x, val_y, val_sample_weight = validation_data  # pylint: disable=unpacking-non-sequence
   1959       else:
   1960         raise ValueError('When passing a `validation_data` argument, it must '
   1961                          'contain either 2 items (x_val, y_val), or 3 items '
   1962                          '(x_val, y_val, val_sample_weights). However we '
   1963                          'received `validation_data=%s`' % validation_data)
   1964       val_x, val_y, val_sample_weights = self._standardize_user_data(
   1965           val_x,
   1966           val_y,
   1967           sample_weight=val_sample_weight,
   1968           batch_size=batch_size,
   1969           steps=validation_steps)
   1970     elif validation_split and 0. < validation_split < 1.:
   1971       if training_utils.has_symbolic_tensors(x):
   1972         raise ValueError('If your data is in the form of symbolic tensors, you '
   1973                          'cannot use `validation_split`.')
   1974       if hasattr(x[0], 'shape'):
   1975         split_at = int(x[0].shape[0] * (1. - validation_split))
   1976       else:
   1977         split_at = int(len(x[0]) * (1. - validation_split))
   1978 
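      # E.g., 1000 training samples with validation_split=0.2 gives
      # split_at=800: the first 800 samples are used for training and the
      # remaining 200 for validation.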
   1979       x, val_x = (slice_arrays(x, 0, split_at), slice_arrays(x, split_at))
   1980       y, val_y = (slice_arrays(y, 0, split_at), slice_arrays(y, split_at))
   1981       sample_weights, val_sample_weights = (
   1982           slice_arrays(sample_weights, 0, split_at),
   1983           slice_arrays(sample_weights, split_at)
   1984       )
   1985     elif validation_steps:
   1986       val_x = []
   1987       val_y = []
   1988       val_sample_weights = []
   1989     else:
   1990       val_x = None
   1991       val_y = None
   1992       val_sample_weights = None
   1993 
   1994     return val_x, val_y, val_sample_weights
   1995 
   1996   def predict(self,
   1997               x,
   1998               batch_size=None,
   1999               verbose=0,
   2000               steps=None,
   2001               max_queue_size=10,
   2002               workers=1,
   2003               use_multiprocessing=False):
   2004     with _tpu_session_context():
   2005       return super(KerasTPUModel, self).predict(
   2006           x,
   2007           batch_size=batch_size,
   2008           verbose=verbose,
   2009           steps=steps,
   2010           max_queue_size=max_queue_size,
   2011           workers=workers,
   2012           use_multiprocessing=use_multiprocessing)
   2013 
   2014   @property
   2015   def optimizer(self):
   2016     if self._tpu_model:
   2017       return self._tpu_model.optimizer
   2018     return self._cpu_model.optimizer
   2019 
   2020   @optimizer.setter
   2021   def optimizer(self, optimizer):
   2022     self._optimizer = optimizer
   2023 
   2024   @property
   2025   def metrics(self):
   2026     if self._tpu_model:
   2027       return self._tpu_model.metrics
   2028     return self._stateful_metric_functions
   2029 
   2030   @metrics.setter
   2031   def metrics(self, metrics):
   2032     self._stateful_metric_functions = metrics
   2033 
   2034   def _make_train_function(self):
   2035     if not self.train_function:
   2036       self.train_function = TPUFunction(
   2037           self,
   2038           model_fn_lib.ModeKeys.TRAIN,
   2039           tpu_assignment=self._tpu_assignment)
   2040 
   2041     return self.train_function
   2042 
   2043   def _make_test_function(self):
   2044     if not self.test_function:
   2045       self.test_function = TPUFunction(
   2046           self, model_fn_lib.ModeKeys.EVAL, tpu_assignment=self._tpu_assignment)
   2047     return self.test_function
   2048 
   2049   def _make_predict_function(self):
   2050     if not self.predict_function:
   2051       self.predict_function = TPUFunction(
   2052           self,
   2053           model_fn_lib.ModeKeys.PREDICT,
   2054           tpu_assignment=self._tpu_assignment)
   2055     return self.predict_function
   2056 
   2057   def _initialize_weights(self, cloned_model):
   2058     """Initialize TPU weights.
   2059 
   2060     This is called on the first compile of the TPU model (first call to
   2061     fit/predict/evaluate).
   2062 
   2063     Args:
   2064       cloned_model: `keras.Model`, TPU model to initialize.
   2065     """
   2066     if self._tpu_weights_initialized:
   2067       return
   2068 
   2069     self._tpu_model = cloned_model
   2070     self._tpu_weights_initialized = True
   2071 
   2072     weights = self._cpu_model.get_weights()
   2073 
   2074     if isinstance(self.cpu_optimizer, keras_optimizers.TFOptimizer):
   2075       cpu_optimizer_config = {}
   2076     else:
   2077       cpu_optimizer_config = self.cpu_optimizer.get_config()
   2078 
   2079     logging.info('Setting weights on TPU model.')
   2080     cloned_model.set_weights(weights)
   2081     if self._tpu_model.optimizer is None:
      # tpu_model may not be compiled, e.g., when loading weights and then
      # calling predict.
   2083       return
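    # For a Keras Adam optimizer, for example, config entries such as 'lr',
    # 'beta_1', 'beta_2', and 'decay' correspond to tf.Variable attributes on
    # the optimizer and are assigned below; plain Python values such as
    # 'epsilon' are skipped with a warning.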
   2084     for k, v in six.iteritems(cpu_optimizer_config):
   2085       if k == 'name':
   2086         continue
   2087       opt_var = getattr(self._tpu_model.optimizer, k)
   2088       if isinstance(opt_var, variables.Variable):
   2089         logging.info('CPU -> TPU %s: %s {%s}', k, v, K.get_value(opt_var))
   2090         K.get_session().run(opt_var.assign(v))
   2091       else:
   2092         logging.warning('Cannot update non-variable config: %s', k)
   2093 
   2094   @property
   2095   def cpu_optimizer(self):
   2096     return self._cpu_model.optimizer
   2097 
   2098   def sync_to_cpu(self):
    """Copy weights from the TPU, returning a synchronized CPU model."""
   2100     if not self._tpu_weights_initialized:
   2101       return self._cpu_model
   2102 
   2103     logging.info('Copying TPU weights to the CPU')
   2104     tpu_weights = self._tpu_model.get_weights()
   2105 
   2106     # TFOptimizers have no configurable options
   2107     if isinstance(self.cpu_optimizer, keras_optimizers.TFOptimizer):
   2108       tpu_optimizer_config = {}
   2109     else:
   2110       tpu_optimizer_config = self._tpu_model.optimizer.get_config()
   2111 
   2112     self._cpu_model.set_weights(tpu_weights)
   2113     for k, v in six.iteritems(tpu_optimizer_config):
   2114       logging.info('TPU -> CPU %s: %s', k, v)
   2115       if k == 'name':
   2116         continue
   2117       opt_var = getattr(self.cpu_optimizer, k)
   2118       if isinstance(opt_var, variables.Variable):
   2119         K.get_session().run(opt_var.assign(v))
   2120       else:
   2121         logging.warning('Cannot update non-variable config: %s', k)
   2122 
   2123     return self._cpu_model
   2124 
   2125   def get_weights(self):
   2126     return self.sync_to_cpu().get_weights()
   2127 
   2128   def save_weights(self, *args, **kw):
   2129     return self.sync_to_cpu().save_weights(*args, **kw)
   2130 
   2131   def save(self, *args, **kw):
   2132     return self.sync_to_cpu().save(*args, **kw)
   2133 
   2134   def set_weights(self, weights):
   2135     # We may not have a TPU model available if we haven't run fit/predict, so
   2136     # we can't directly set the TPU weights here.
   2137     # Instead, reset CPU model weights and force TPU re-initialization at the
   2138     # next call.
   2139     self._cpu_model.set_weights(weights)
   2140     self._tpu_weights_initialized = False
   2141 
   2142   def load_weights(self, filepath, by_name=False):
   2143     self._cpu_model.load_weights(filepath, by_name)
   2144     self._tpu_weights_initialized = False
   2145 
   2146 
   2147 # pylint: disable=bad-continuation
   2148 def _validate_shapes(model):
   2149   """Validate that all layers in `model` have constant shape."""
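  # For example, a model built on `tf.keras.layers.Input(shape=(None, 10))`
  # (variable sequence length) fails this check; specifying a fixed length,
  # e.g. `Input(shape=(20, 10))`, makes all non-batch dimensions constant.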
   2150   for layer in model.layers:
   2151     if isinstance(layer.input_shape, tuple):
   2152       input_shapes = [layer.input_shape]
   2153     else:
   2154       input_shapes = layer.input_shape
   2155 
   2156     if isinstance(layer.output_shape, tuple):
   2157       output_shapes = [layer.output_shape]
   2158     else:
   2159       output_shapes = layer.output_shape
   2160 
   2161     for shape in input_shapes + output_shapes:
   2162       for dim in shape[1:]:
   2163         if dim is None:
   2164           raise ValueError(
   2165               """
   2166 Layer %(layer)s has a variable shape in a non-batch dimension.  TPU models must
   2167 have constant shapes for all operations.
   2168 
   2169 You may have to specify `input_length` for RNN/TimeDistributed layers.
   2170 
   2171 Layer: %(layer)s
   2172 Input shape: %(input_shape)s
   2173 Output shape: %(output_shape)s
   2174   """ % {
   2175           'layer': layer,
   2176           'input_shape': layer.input_shape,
   2177           'output_shape': layer.output_shape
   2178           })
   2179 
   2180 
   2181 # pylint: enable=bad-continuation
   2182 
   2183 
   2184 @deprecated(
   2185     '2019-02-20', 'Switch to tf.contrib.distribute.TPUStrategy. '
   2186     'https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy'
   2187 )
   2188 def tpu_model(model, strategy=None):
   2189   """Copy `model` along with weights to the TPU.
   2190 
   2191   Returns a TPU model.
   2192 
   2193   Usage:
   2194   ```
   2195   a = Input(shape=(32,))
   2196   b = Dense(32)(a)
   2197   model = Model(inputs=a, outputs=b)
   2198 
   2199   # If `num_cores_per_host` is greater than one, batch parallelism will be used
   2200   # to run on multiple TPU cores.
   2201   strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
   2202   model = keras_support.tpu_model(model, strategy)
   2203   model.compile(
   2204       optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0),
   2205       ...)
   2206   ```
   2207 
   2208   Args:
   2209     model: A `tf.keras.Model` instance.
   2210     strategy: `TPUDistributionStrategy`.  The strategy to use for replicating
   2211       model across multiple TPU cores.
   2212 
   2213   Returns:
   2214     A new `KerasTPUModel` instance.
   2215   """
   2216   _validate_shapes(model)
   2217   # TODO(xiejw): Validate TPU model. TPUModel only?
   2218   # TODO(xiejw): Validate replicas. Full or 1. Shall we allow subset?
   2219   # TODO(xiejw): Adds reduction option.
   2220 
   2221   if strategy is None:
   2222     strategy = TPUDistributionStrategy()
   2223   else:
   2224     if not isinstance(strategy, TPUDistributionStrategy):
   2225       raise TypeError(
   2226           '`strategy` must have type `tf.contrib.tpu.TPUDistributionStrategy`. '
   2227           'Got: {}'.format(type(strategy)))
   2228 
   2229   # If the model has already been initialized, grab the optimizer configuration
   2230   # and model weights before entering the TPU session.
   2231   if model.optimizer:
   2232     if (isinstance(model.optimizer, keras_optimizers.Optimizer) and not
   2233         isinstance(model.optimizer, keras_optimizers.TFOptimizer)):
   2234       optimizer_config = model.optimizer.get_config()
   2235     else:
   2236       optimizer_config = None
   2237     model_weights = model.get_weights()
   2238   else:
   2239     model_weights = None
   2240 
   2241   setup_tpu_session(strategy._tpu_cluster_resolver)
   2242 
   2243   # Force initialization of the CPU model in the TPU session.
   2244   cpu_model = models.clone_model(model)
   2245   if model.optimizer:
   2246     cpu_model.compile(
   2247         _clone_optimizer(model.optimizer, optimizer_config),
   2248         model.loss,
   2249         metrics_module.clone_metrics(model._compile_metrics),
   2250         model.loss_weights,
   2251         model.sample_weight_mode,
   2252         metrics_module.clone_metrics(model._compile_weighted_metrics),
   2253     )
   2254 
   2255   if model_weights:
   2256     cpu_model.set_weights(model_weights)
   2257     cpu_model.reset_states()
   2258 
   2259   return KerasTPUModel(cpu_model=cpu_model, strategy=strategy)
   2260