      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ===================================================================
     15 """TPUEstimator class."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import collections
     22 import copy
     23 import os
     24 import signal
     25 import sys
     26 import threading
     27 import time
     28 
     29 import numpy as np
     30 import six
     31 from six.moves import queue as Queue  # pylint: disable=redefined-builtin
     32 from six.moves import xrange  # pylint: disable=redefined-builtin
     33 
     34 from tensorflow.core.framework import variable_pb2
     35 from tensorflow.core.framework.summary_pb2 import Summary
     36 from tensorflow.core.protobuf import config_pb2
     37 from tensorflow.core.protobuf.tpu import compilation_result_pb2 as tpu_compilation_result
     38 from tensorflow.python.client import session as tf_session
     39 from tensorflow.python.data.ops import dataset_ops
     40 from tensorflow.python.data.util import nest as data_nest
     41 from tensorflow.python.estimator import estimator as estimator_lib
     42 from tensorflow.python.estimator import model_fn as model_fn_lib
     43 from tensorflow.python.estimator.export import export_output as export_output_lib
     44 from tensorflow.python.framework import constant_op
     45 from tensorflow.python.framework import dtypes
     46 from tensorflow.python.framework import errors
     47 from tensorflow.python.framework import function
     48 from tensorflow.python.framework import ops
     49 from tensorflow.python.ops import array_ops
     50 from tensorflow.python.ops import check_ops
     51 from tensorflow.python.ops import control_flow_ops
     52 from tensorflow.python.ops import init_ops
     53 from tensorflow.python.ops import math_ops
     54 from tensorflow.python.ops import resource_variable_ops
     55 from tensorflow.python.ops import state_ops
     56 from tensorflow.python.ops import summary_ops_v2 as contrib_summary
     57 from tensorflow.python.ops import variable_scope
     58 from tensorflow.python.ops import variables
     59 from tensorflow.python.platform import tf_logging as logging
     60 from tensorflow.python.saved_model import tag_constants
     61 from tensorflow.python.summary import summary
     62 from tensorflow.python.tpu import _tpu_estimator_embedding
     63 from tensorflow.python.tpu import error_handling
     64 from tensorflow.python.tpu import functional as tpu_functional
     65 from tensorflow.python.tpu import session_support
     66 from tensorflow.python.tpu import tensor_tracer
     67 from tensorflow.python.tpu import tpu
     68 from tensorflow.python.tpu import tpu_config
     69 from tensorflow.python.tpu import tpu_context
     70 from tensorflow.python.tpu import tpu_embedding_gradient
     71 from tensorflow.python.tpu import tpu_feed
     72 from tensorflow.python.tpu import tpu_function
     73 from tensorflow.python.tpu import training_loop
     74 from tensorflow.python.tpu import util as util_lib
     75 from tensorflow.python.tpu._tpu_estimator_embedding import AdagradParameters  # pylint: disable=unused-import
     76 from tensorflow.python.tpu._tpu_estimator_embedding import AdamParameters  # pylint: disable=unused-import
     77 from tensorflow.python.tpu._tpu_estimator_embedding import StochasticGradientDescentParameters  # pylint: disable=unused-import
     78 from tensorflow.python.tpu._tpu_estimator_embedding import EmbeddingConfigSpec  # pylint: disable=unused-import
     79 from tensorflow.python.tpu.ops import tpu_ops
     80 from tensorflow.python.training import basic_session_run_hooks
     81 from tensorflow.python.training import evaluation
     82 from tensorflow.python.training import session_run_hook
     83 from tensorflow.python.training import training
     84 from tensorflow.python.training import training_util
     85 from tensorflow.python.util import function_utils
     86 from tensorflow.python.util import nest
     87 from tensorflow.python.util import tf_inspect
     88 
     89 _INITIAL_LOSS = 1e7
     90 _ZERO_LOSS = 0.
     91 _TPU_ESTIMATOR = 'tpu_estimator'
     92 _ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop'
     93 _BATCH_SIZE_KEY = 'batch_size'
     94 _CTX_KEY = 'context'
     95 _USE_TPU_KEY = 'use_tpu'
     96 _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum'
     97 _ONE_GIGABYTE = 1024 * 1024 * 1024
     98 _TPU_ENQUEUE_OPS = '_tpu_enqueue_ops'
     99 _TPU_TRAIN_OP = '_tpu_train_op'
    100 _REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference'
    101 _KEY_WHEN_PREDICTIONS_IS_A_TENSOR = '_key_when_predictions_is_a_tensor'
    102 
    103 # Ideally _USE_TPU_KEY should be reserved as well. However, there are already
    104 # models that make use of this key, so it cannot be reserved now without
    105 # breaking them. In the long run, we would like to mitigate this by migrating
    106 # models off of _USE_TPU_KEY.
    107 _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY]
    108 
    109 # TODO(b/65703635): Flip the value and remove all dead code. Currently, this is
    110 # only used for per-core based deployments. For per-host based pipelines, if a
    111 # user returns a Dataset instance it will be automatically wrapped in a
    112 # tf.while_loop (This can be disabled by returning features and labels
    113 # explicitly).
    114 _WRAP_INPUT_FN_INTO_WHILE_LOOP = False
    115 
    116 ops.register_proto_function(
    117     '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR),
    118     proto_type=variable_pb2.VariableDef,
    119     to_proto=resource_variable_ops._to_proto_fn,  # pylint: disable=protected-access
    120     from_proto=resource_variable_ops._from_proto_fn)  # pylint: disable=protected-access
    121 
    122 
    123 def _is_iterable(obj):
    124   """A Python 2 and 3 compatible util to check whether `obj` is iterable."""
    125   try:
    126     iter(obj)
    127     return True
    128   except TypeError:
    129     return False
    130 
    131 
    132 class CatchInvalidHostcallFunctions(control_flow_ops.XLAControlFlowContext):
    133 
    134   def AddOp(self, op):
    135     if op.type in [
    136         'AudioSummary', 'AudioSummaryV2', 'HistogramSummary', 'ImageSummary',
    137         'MergeSummary', 'ScalarSummary', 'TensorSummary', 'TensorSummaryV2'
    138     ]:
    139       raise ValueError('Use tf.contrib.summary inside of host_calls.')
    140 
    141 
    142 def _create_global_step(graph):
    143   graph = graph or ops.get_default_graph()
    144   if training.get_global_step(graph) is not None:
    145     raise ValueError('"global_step" already exists.')
    146   # Create in proper graph and base name_scope.
    147   with graph.as_default() as g, g.name_scope(None):
    148     return variable_scope.get_variable(
    149         ops.GraphKeys.GLOBAL_STEP,
    150         shape=[],
    151         dtype=dtypes.int64,
    152         initializer=init_ops.zeros_initializer(),
    153         trainable=False,
    154         use_resource=True,
    155         collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP])
    156 
    157 
    158 def _create_or_get_iterations_per_loop():
    159   """Creates or gets the iterations_per_loop variable.
    160 
    161   In TPUEstimator, the user provided computation, the model_fn, is wrapped
    162   inside a tf.while_loop for peak performance. The iterations of the loop are
    163   specified by this variable, which adjusts its value on the CPU after each TPU
    164   program execution and before the next TPU execution.
    165 
    166   The purpose of using a variable, rather than a constant, is to allow
    167   TPUEstimator to adapt the number of TPU training iterations to the final steps
    168   specified by the user. For example, if the user sets iterations_per_loop to 4
    169   in TPUConfig and steps to 10 in TPUEstimator.train(), the iterations_per_loop
    170   variable will have the following values before each TPU execution.
    171 
    172       - 1st TPU execution: iterations_per_loop = 4
    173       - 2nd TPU execution: iterations_per_loop = 4
    174       - 3rd TPU execution: iterations_per_loop = 2
    175 
    176   As model_fn increases the global step once per train_op invocation, the global
    177   step is 10 after all TPU executions, matching the steps=10 argument passed in
    178   by the user.
    179 
    180   Returns:
    181     A TF non-trainable resource variable.
    182 
    183   Raises:
    184     RuntimeError: If multiple iterations_per_loop variables are found.
    185   """
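          # Illustration (not from the original code): `_TPUStopAtStepHook` below
          # reloads this variable before each outer Session.run with
          # min(last_step - global_step, iterations_per_loop), which yields the
          # 4, 4, 2 sequence from the docstring for iterations_per_loop=4, steps=10.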
    186   graph = ops.get_default_graph()
    187   collection_name = '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR)
    188   iter_vars = graph.get_collection(collection_name)
    189   if len(iter_vars) == 1:
    190     return iter_vars[0]
    191   elif len(iter_vars) > 1:
    192     raise RuntimeError('Multiple iterations_per_loop_var in collection.')
    193 
    194   with ops.colocate_with(training_util.get_global_step()):
    195     with variable_scope.variable_scope(
    196         _TPU_ESTIMATOR, reuse=variable_scope.AUTO_REUSE):
    197       return variable_scope.get_variable(
    198           _ITERATIONS_PER_LOOP_VAR,
    199           initializer=init_ops.zeros_initializer(),
    200           shape=[],
    201           dtype=dtypes.int32,
    202           trainable=False,
    203           collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES],
    204           use_resource=True)
    205 
    206 
    207 def _sync_variables_ops(ctx):
    208   """Creates variable synchronization ops.
    209 
    210   Gets the variables back from TPU nodes. This means the variables updated
    211   by TPU will now be *synced* to host memory.
    212   In BROADCAST mode, we skip this sync since the variables are usually too
    213   big to transmit via RPC.
    214 
    215   Args:
    216     ctx: A `_InternalTPUContext` instance with mode.
    217 
    218   Returns:
    219     A list of sync ops.
    220   """
    221 
    222   if not ctx.is_input_broadcast_with_iterators():
    223     return [
    224         array_ops.check_numerics(v.read_value(),
    225                                  'Gradient for %s is NaN' % v.name).op
    226         for v in variables.trainable_variables()
    227     ]
    228   else:
    229     return [control_flow_ops.no_op()]
    230 
    231 
    232 def _increase_eval_step_op(iterations_per_loop):
    233   """Returns an op to increase the eval step for TPU evaluation.
    234 
    235   Args:
    236     iterations_per_loop: Tensor. The number of eval steps running in TPU system
    237       before returning to CPU host for each `Session.run`.
    238 
    239   Returns:
    240     An operation
    241   """
    242   eval_step = evaluation._get_or_create_eval_step()  # pylint: disable=protected-access
    243   # Estimator.evaluate already adds 1 to the eval step, so add the remainder.
    244   return state_ops.assign_add(
    245       eval_step,
    246       math_ops.cast(iterations_per_loop - 1, dtype=eval_step.dtype),
    247       use_locking=True)
    248 
    249 
    250 def _extract_key_names(tensor_or_dict):
    251   if isinstance(tensor_or_dict, dict):
    252     return sorted(tensor_or_dict.keys())
    253   return []
    254 
    255 
    256 class _SIGNAL(object):
    257   """Signals used to control the infeed/outfeed threads.
    258 
    259   All reserved signals must be negative numbers. Positive numbers are used to
    260   indicate the number of iterations for the next training/evaluation loop.
    261   """
    262   NEXT_BATCH = -1
    263   STOP = -2
    264 
    265 
    266 class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
    267   """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`.
    268 
    269   See `EstimatorSpec` for `mode`, `predictions`, `loss`, `train_op`, and
    270   `export_outputs`.
    271 
    272   For evaluation, `eval_metrics` is a tuple of `metric_fn` and `tensors`, where
    273   `metric_fn` runs on CPU to generate metrics and `tensors` represents the
    274   `Tensor`s transferred from the TPU system to the CPU host and passed to
    275   `metric_fn`. To be precise, TPU evaluation expects a slightly different
    276   signature from `tf.estimator.Estimator`: while `EstimatorSpec.eval_metric_ops`
    277   expects a dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and
    278   `tensors`. The `tensors` can be a list of `Tensor`s or a dict mapping names to
    279   `Tensor`s. The `tensors` usually specify the model logits, which are
    280   transferred back from the TPU system to the CPU host. All tensors must be
    281   batch-major, i.e., the batch size is the first dimension. Once all tensors are
    282   available at the CPU host from all shards, they are concatenated (on CPU) and
    283   passed to `metric_fn` as positional arguments if `tensors` is a list, or as
    284   keyword arguments if `tensors` is a dict. `metric_fn` takes the `tensors` and
    285   returns a dict from metric string name to the result of calling a metric
    286   function, namely a `(metric_tensor, update_op)` tuple. See `TPUEstimator` for
    287   an MNIST example of how to specify `eval_metrics`.
    288 
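          For illustration, a minimal `eval_metrics` sketch (`my_metric_fn`, `labels`,
          `logits`, `mode` and `loss` are placeholders from a hypothetical model_fn):

            def my_metric_fn(labels, logits):
              predictions = tf.argmax(logits, axis=-1)
              return {'accuracy': tf.metrics.accuracy(labels, predictions)}

            # `tensors` is a list here, so `metric_fn` receives positional arguments.
            spec = TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metrics=(my_metric_fn, [labels, logits]))
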
    289   `scaffold_fn` is a function running on CPU to generate the `Scaffold`. This
    290   function should not capture any Tensors in `model_fn`.
    291 
    292   `host_call` is a tuple of a `function` and a list or dictionary of `tensors`
    293   to pass to that function; the function returns a list of Tensors. `host_call`
    294   currently works for train() and evaluate(). The function is executed on the
    295   CPU on every step, so there is communication overhead when sending tensors
    296   from TPU to CPU. To reduce the overhead, try reducing the
    297   size of the tensors. The `tensors` are concatenated along their major (batch)
    298   dimension, and so must be >= rank 1. The `host_call` is useful for writing
    299   summaries with `tf.contrib.summary.create_file_writer`.
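
          For illustration, a minimal `host_call` sketch (`host_call_fn`, `global_step`,
          `loss` and the log directory are placeholders; the summary ops use the
          tf.contrib.summary API mentioned above):

            def host_call_fn(gs, loss):
              # Runs on the CPU host. Inputs arrive concatenated along the batch
              # dimension, hence the rank-1 tensors and the [0] indexing.
              with tf.contrib.summary.create_file_writer('/tmp/logdir').as_default():
                with tf.contrib.summary.always_record_summaries():
                  tf.contrib.summary.scalar('loss', loss[0], step=gs[0])
                  return tf.contrib.summary.all_summary_ops()

            host_call = (host_call_fn,
                         [tf.reshape(global_step, [1]), tf.reshape(loss, [1])])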
    300   """
    301 
    302   def __new__(cls,
    303               mode,
    304               predictions=None,
    305               loss=None,
    306               train_op=None,
    307               eval_metrics=None,
    308               export_outputs=None,
    309               scaffold_fn=None,
    310               host_call=None,
    311               training_hooks=None,
    312               evaluation_hooks=None,
    313               prediction_hooks=None):
    314     """Creates a validated `TPUEstimatorSpec` instance."""
    315     host_calls = {}
    316     if eval_metrics is not None:
    317       host_calls['eval_metrics'] = eval_metrics
    318     if host_call is not None:
    319       host_calls['host_call'] = host_call
    320     _OutfeedHostCall.validate(host_calls)
    321 
    322     training_hooks = tuple(training_hooks or [])
    323     evaluation_hooks = tuple(evaluation_hooks or [])
    324     prediction_hooks = tuple(prediction_hooks or [])
    325 
    326     for hook in training_hooks + evaluation_hooks + prediction_hooks:
    327       if not isinstance(hook, session_run_hook.SessionRunHook):
    328         raise TypeError('All hooks must be SessionRunHook instances, given: {}'
    329                         .format(hook))
    330 
    331     return super(TPUEstimatorSpec, cls).__new__(
    332         cls,
    333         mode=mode,
    334         predictions=predictions,
    335         loss=loss,
    336         train_op=train_op,
    337         eval_metrics=eval_metrics,
    338         export_outputs=export_outputs,
    339         scaffold_fn=scaffold_fn,
    340         host_call=host_call,
    341         training_hooks=training_hooks,
    342         evaluation_hooks=evaluation_hooks,
    343         prediction_hooks=prediction_hooks)
    344 
    345   def as_estimator_spec(self):
    346     """Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
    347     host_calls = {}
    348     if self.eval_metrics is not None:
    349       host_calls['eval_metrics'] = self.eval_metrics
    350     if self.host_call is not None:
    351       host_calls['host_call'] = self.host_call
    352     host_call_ret = _OutfeedHostCall.create_cpu_hostcall(host_calls)
    353     eval_metric_ops = None
    354     if self.eval_metrics is not None:
    355       eval_metric_ops = host_call_ret['eval_metrics']
    356     hooks = None
    357     if self.host_call is not None:
    358       hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])]
    359     loss = self.loss
    360     if tensor_tracer.TensorTracer.is_enabled() \
    361        and self.train_op is not None:
    362       tt = tensor_tracer.TensorTracer()
    363       loss = tt.trace_cpu(ops.get_default_graph(), loss, self.train_op)
    364 
    365     hooks = tuple(hooks or [])
    366     scaffold = self.scaffold_fn() if self.scaffold_fn else None
    367     return model_fn_lib.EstimatorSpec(
    368         mode=self.mode,
    369         predictions=self.predictions,
    370         loss=loss,
    371         train_op=self.train_op,
    372         eval_metric_ops=eval_metric_ops,
    373         export_outputs=self.export_outputs,
    374         scaffold=scaffold,
    375         training_hooks=self.training_hooks + hooks,
    376         evaluation_hooks=self.evaluation_hooks + hooks,
    377         prediction_hooks=self.prediction_hooks + hooks)
    378 
    379 
    380 class _OpQueueContext(object):
    381   """Manages the work queue and thread for an infeed/outfeed thread."""
    382 
    383   def __init__(self, name, target, args):
    384     self._name = name
    385     self._queue = Queue.Queue()
    386     args = (self,) + args
    387     self._thread = threading.Thread(name=name, target=target, args=args)
    388     self._thread.daemon = True
    389     self._thread.start()
    390 
    391   def stop(self):
    392     self._queue.put(_SIGNAL.STOP)
    393 
    394   def send_next_batch_signal(self, iterations):
    395     self._queue.put(iterations)
    396 
    397   def read_iteration_counts(self):
    398     while True:
    399       iterations = self._queue.get(block=True)
    400       logging.debug('%s read iterations %s', self._name, iterations)
    401       if iterations == _SIGNAL.STOP:
    402         logging.info('%s received shutdown signal, stopping.', self._name)
    403         return
    404       yield iterations
    405 
    406   def join(self):
    407     logging.info('Shutting down %s thread.', self._name)
    408     self.stop()
    409     self._thread.join()
    410 
    411 
    412 class _OpSignalOnceQueueContext(_OpQueueContext):
    413   """Manages the work queue and thread for an infeed/outfeed thread.
    414 
    415   This subclass only signals once.
    416   """
    417 
    418   def __init__(self, name, target, args):
    419     super(_OpSignalOnceQueueContext, self).__init__(name, target, args)
    420     self._has_signaled = False
    421 
    422   def send_next_batch_signal(self, iterations):
    423     if not self._has_signaled:
    424       self._queue.put(iterations)
    425       self._has_signaled = True
    426 
    427 
    428 class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
    429   """A Session hook setting up the TPU initialization, infeed, and outfeed.
    430 
    431   This hook does two major things:
    432   1. initializes and shuts down the TPU system.
    433   2. launches and joins the threads for infeed enqueue and (optional) outfeed
    434      dequeue.
    435   """
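          # Illustration (not from the original code): for each outer Session.run,
          # before_run() reads iterations_per_loop and signals both controllers via
          # send_next_batch_signal(); the infeed and outfeed threads then run their
          # enqueue/dequeue ops that many times (see _run_infeed/_run_outfeed below).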
    436 
    437   def __init__(self,
    438                ctx,
    439                enqueue_ops,
    440                dequeue_ops,
    441                tpu_compile_op,
    442                run_infeed_loop_on_coordinator=True,
    443                rendezvous=None,
    444                master=None,
    445                session_config=None,
    446                tpu_init_ops=None):
    447     self._master_job = ctx.master_job
    448     self._enqueue_ops = enqueue_ops
    449     self._dequeue_ops = dequeue_ops
    450     self._rendezvous = rendezvous
    451     self._master = master
    452     self._session_config = session_config
    453     self._init_ops = list(tpu_init_ops or [])
    454     if ctx.embedding_config is None:
    455       self._embedding_layer_config = None
    456     else:
    457       self._embedding_layer_config = (
    458           ctx.embedding_config.tpu_embedding.config_proto)
    459     self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator
    460     self._initial_infeed_sleep_secs = (
    461         ctx.config.tpu_config.initial_infeed_sleep_secs)
    462 
    463     self._feed_error = None
    464     self._finished = False
    465     # When using model parallelism, the TPU is pre-initialized at startup to
    466     # fetch mesh information.  We skip re-initializing it here to avoid
    467     # suspected issues due to the mesh layout changing on the second
    468     # initialization.
    469     self._should_initialize_tpu = not ctx.model_parallelism_enabled
    470     self._tpu_compile_op = tpu_compile_op
    471 
    472   def begin(self):
    473     logging.info('TPU job name %s', self._master_job)
    474     self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
    475     if self._should_initialize_tpu:
    476       self._finalize_ops = [tpu.shutdown_system(job=self._master_job)]
    477     else:
    478       self._finalize_ops = []
    479 
    480     summary_writer_init_ops = contrib_summary.summary_writer_initializer_op()
    481     self._init_ops.extend(summary_writer_init_ops)
    482     # Get all the writer resources from the initializer, so we know what to
    483     # flush.
    484     for op in summary_writer_init_ops:
    485       self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0]))
    486 
    487   def _run_infeed(self, queue_ctx, session):
    488     logging.info('Starting infeed thread controller.')
    489     if self._initial_infeed_sleep_secs:
    490       logging.info('Infeed thread sleeping for %d seconds.',
    491                    self._initial_infeed_sleep_secs)
    492       time.sleep(self._initial_infeed_sleep_secs)
    493       logging.info('Infeed thread starting after sleep')
    494 
    495     with self._rendezvous.catch_errors(source='infeed', session=session):
    496       if self._run_infeed_loop_on_coordinator:
    497         for count, steps in enumerate(queue_ctx.read_iteration_counts()):
    498           for i in xrange(steps):
    499             logging.debug('Infeed enqueue for iteration (%d, %d)', count, i)
    500             session.run(self._enqueue_ops)
    501       else:
    502         for _ in queue_ctx.read_iteration_counts():
    503           session.run(self._enqueue_ops)
    504       logging.info('Infeed thread finished, shutting down.')
    505 
    506   def _run_outfeed(self, queue_ctx, session):
    507     logging.info('Starting outfeed thread controller.')
    508     with self._rendezvous.catch_errors(source='outfeed', session=session):
    509       for count, steps in enumerate(queue_ctx.read_iteration_counts()):
    510         for i in xrange(steps):
    511           logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i)
    512           session.run(self._dequeue_ops)
    513       logging.info('Outfeed thread finished, shutting down.')
    514 
    515   def _create_infeed_controller(self, name, target, args):
    516     return _OpQueueContext(name=name, target=target, args=args)
    517 
    518   def _assertCompilationSucceeded(self, result, coord):
    519     proto = tpu_compilation_result.CompilationResultProto()
    520     proto.ParseFromString(result)
    521     if proto.status_error_message:
    522       logging.error('Compilation failed: {}'.format(proto.status_error_message))
    523       coord.request_stop()
    524     else:
    525       logging.info('Compilation succeeded')
    526 
    527   def after_create_session(self, session, coord):
    528     if self._should_initialize_tpu:
    529       logging.info('Init TPU system')
    530       start = time.time()
    531       with ops.Graph().as_default():
    532         with tf_session.Session(
    533             self._master, config=self._session_config) as sess:
    534           sess.run(
    535               tpu.initialize_system(
    536                   job=self._master_job,
    537                   embedding_config=self._embedding_layer_config))
    538       logging.info('Initialized TPU in %d seconds', time.time() - start)
    539 
    540     session.run(self._init_ops,
    541                 options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000))
    542 
    543     if os.environ.get('TPU_SPLIT_COMPILE_AND_EXECUTE', '') == '1':
    544       logging.info('Compiling user program: this may take a while...')
    545       self._assertCompilationSucceeded(session.run(self._tpu_compile_op), coord)
    546 
    547     self._infeed_controller = self._create_infeed_controller(
    548         name='InfeedController', target=self._run_infeed, args=(session,))
    549 
    550     self._outfeed_controller = _OpQueueContext(
    551         name='OutfeedController', target=self._run_outfeed, args=(session,))
    552 
    553     # Enable the worker watchdog to terminate workers on coordinator exit.
    554     watchdog_timeout = int(os.environ.get('TF_TPU_WATCHDOG_TIMEOUT', '0'))
    555     if watchdog_timeout > 0:
    556       session_support.start_worker_watchdog(session,
    557                                             shutdown_timeout=watchdog_timeout)
    558 
    559   def before_run(self, run_context):
    560     self._feed_error = None
    561 
    562     iterations = run_context.session.run(self._iterations_per_loop_var)
    563 
    564     logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations)
    565     self._infeed_controller.send_next_batch_signal(iterations)
    566 
    567     logging.info('Dequeue next (%d) batch(es) of data from outfeed.',
    568                  iterations)
    569     self._outfeed_controller.send_next_batch_signal(iterations)
    570 
    571   def end(self, session):
    572     self._finished = True
    573     logging.info('Stop infeed thread controller')
    574     self._infeed_controller.join()
    575     self._rendezvous.record_done('infeed')
    576 
    577     logging.info('Stop output thread controller')
    578     self._outfeed_controller.join()
    579     self._rendezvous.record_done('outfeed')
    580 
    581     logging.info('Shutdown TPU system.')
    582     session.run(self._finalize_ops)
    583 
    584 
    585 class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook):
    586 
    587   def __init__(self, ctx, enqueue_ops, dequeue_ops, tpu_compile_op,
    588                rendezvous=None, master=None, session_config=None):
    589     super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__(
    590         ctx,
    591         enqueue_ops,
    592         dequeue_ops,
    593         tpu_compile_op=tpu_compile_op,
    594         run_infeed_loop_on_coordinator=False,
    595         rendezvous=rendezvous,
    596         master=master,
    597         session_config=session_config)
    598 
    599   def _create_infeed_controller(self, name, target, args):
    600     return _OpSignalOnceQueueContext(name=name, target=target, args=args)
    601 
    602 
    603 class _TPUStopAtStepHook(session_run_hook.SessionRunHook):
    604   """Hook that requests stop at a specified step.
    605 
    606   This hook is similar to `session_run_hook._StopAfterNEvalsHook`, with the
    607   following differences for TPU training:
    608 
    609   1. This hook sets the variable for iterations_per_loop, which is used by
    610      `TPUInfeedOutfeedSessionHook` to control the iterations for infeed/outfeed.
    611      As the hook execution order is not guaranteed, the variable update is
    612      handled in `after_create_session` and `after_run` as
    613      `TPUInfeedOutfeedSessionHook` reads the variable value in `before_run`.
    614 
    615   2. For each training loop (session.run), the global step could be increased
    616      multiple times on TPU. The global step tensor value will be explicitly read
    617      again in `after_run` to ensure the latest value is retrieved, avoiding a
    618      race condition.
    619   """
    620 
    621   def __init__(self, iterations, num_steps=None, last_step=None):
    622     """Initializes a `_TPUStopAtStepHook`.
    623 
    624     Args:
    625       iterations: The number of optimizer iterations to run per training loop.
    626       num_steps: Number of steps to execute.
    627       last_step: Step after which to stop.
    628 
    629     Raises:
    630       ValueError: If one of the arguments is invalid.
    631     """
    632     if num_steps is None and last_step is None:
    633       raise ValueError('One of num_steps or last_step must be specified.')
    634     if num_steps is not None and last_step is not None:
    635       raise ValueError('Only one of num_steps or last_step can be specified.')
    636     self._num_steps = num_steps
    637     self._last_step = last_step
    638     self._iterations = iterations
    639 
    640   def _next_iterations(self, global_step, last_step):
    641     gap = last_step - global_step
    642     return min(gap, self._iterations)
    643 
    644   def begin(self):
    645     self._global_step_tensor = training_util.get_global_step()
    646     if self._global_step_tensor is None:
    647       raise RuntimeError('Global step should be created.')
    648 
    649     self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
    650 
    651   def after_create_session(self, session, coord):
    652     global_step = session.run(self._global_step_tensor)
    653     if self._last_step is None:
    654       self._last_step = global_step + self._num_steps
    655 
    656     iterations = self._next_iterations(global_step, self._last_step)
    657 
    658     self._iterations_per_loop_var.load(iterations, session=session)
    659 
    660   def after_run(self, run_context, run_values):
    661     # The global step cannot be retrieved via SessionRunArgs and before_run due
    662     # to a race condition.
    663     global_step = run_context.session.run(self._global_step_tensor)
    664     if global_step >= self._last_step:
    665       run_context.request_stop()
    666     else:
    667       iterations = self._next_iterations(global_step, self._last_step)
    668       self._iterations_per_loop_var.load(
    669           iterations, session=run_context.session)
    670 
    671 
    672 class _SetEvalIterationsHook(session_run_hook.SessionRunHook):
    673   """Hook that sets the number of eval iterations for each loop."""
    674 
    675   def __init__(self, num_steps):
    676     """Initializes a `_SetEvalIterationsHook`.
    677 
    678     Args:
    679       num_steps: Number of steps to execute.
    680     """
    681     self._num_steps = num_steps
    682 
    683   def begin(self):
    684     self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
    685 
    686   def after_create_session(self, session, coord):
    687     self._iterations_per_loop_var.load(self._num_steps, session=session)
    688 
    689 
    690 class _StoppingPredictHook(session_run_hook.SessionRunHook):
    691   """Hook that requests stop according to the stopping signal in prediction."""
    692 
    693   def __init__(self, scalar_stopping_signal):
    694     self._scalar_stopping_signal = scalar_stopping_signal
    695 
    696   def begin(self):
    697     self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
    698 
    699   def after_create_session(self, session, coord):
    700     # This is not necessary as we do not run infeed enqueue and outfeed dequeue
    701     # in side threads for the prediction model, but it makes
    702     # TPUInfeedOutfeedSessionHook print a nicer message.
    703     self._iterations_per_loop_var.load(1, session=session)
    704 
    705   def before_run(self, run_context):
    706     return session_run_hook.SessionRunArgs(self._scalar_stopping_signal)
    707 
    708   def after_run(self, run_context, run_values):
    709     _ = run_context
    710     scalar_stopping_signal = run_values.results
    711     if _StopSignals.should_stop(scalar_stopping_signal):
    712       # NOTE(xiejw): In prediction, stopping signals are inserted for each
    713       # batch. And we append one more batch to signal the system it should stop.
    714       # The data flow might look like
    715       #
    716       #  batch   0: images, labels, stop = 0  (user provided)
    717       #  batch   1: images, labels, stop = 0  (user provided)
    718       #  ...
    719       #  batch  99: images, labels, stop = 0  (user provided)
    720       #  batch 100: images, labels, stop = 1  (TPUEstimator appended)
    721       #
    722       # where the final batch (id = 100) is appended by TPUEstimator, so we
    723       # should drop it before returning the predictions to the user.
    724       # To achieve that, we raise OutOfRangeError in after_run. Once
    725       # MonitoredSession sees this error in SessionRunHook.after_run, the
    726       # "current" prediction, i.e., the batch with id=100, will be discarded
    727       # immediately.
    728       raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.')
    729 
    730 
    731 def generate_per_core_enqueue_ops_fn_for_host(
    732     ctx, input_fn, inputs_structure_recorder, host_device, host_id):
    733   """Generates infeed enqueue ops for per-core input_fn on a single host."""
    734   captured_infeed_queue = _CapturedObject()
    735   tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)
    736 
    737   def enqueue_ops_fn():
    738     """A fn that returns enqueue_ops."""
    739     num_cores_per_host = ctx.num_of_cores_per_host
    740     per_host_sharded_inputs = []
    741     for core_ordinal in range(num_cores_per_host):
    742       with ops.name_scope('ordinal_%d' % (core_ordinal)):
    743         user_context = tpu_context.TPUContext(
    744             internal_ctx=ctx,
    745             input_device=host_device,
    746             invocation_index=host_id * ctx.num_of_cores_per_host + core_ordinal)
    747         inputs = _Inputs.from_input_fn(input_fn(user_context))
    748         if inputs.is_dataset:
    749           raise TypeError(
    750               '`input_fn` returning `Dataset` is not yet supported in '
    751               'per-Core input pipeline deployment. Please set '
    752               'TPUConfig.per_host_input_for_training to True or return '
    753               '`features` and `labels` from `input_fn`')
    754         features, labels = inputs.features_and_labels()
    755 
    756         inputs_structure_recorder.validate_and_record_structure(
    757             features, labels)
    758         flattened_inputs = (
    759             inputs_structure_recorder.flatten_features_and_labels(
    760                 features, labels))
    761         per_host_sharded_inputs.append(flattened_inputs)
    762 
    763     infeed_queue = tpu_feed.InfeedQueue(
    764         number_of_tuple_elements=len(per_host_sharded_inputs[0]))
    765     captured_infeed_queue.capture(infeed_queue)
    766 
    767     per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
    768         per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl)
    769     return per_host_enqueue_ops
    770 
    771   return enqueue_ops_fn, captured_infeed_queue
    772 
    773 
    774 def generate_per_host_enqueue_ops_fn_for_host(
    775     ctx, input_fn, inputs_structure_recorder, batch_axis, device, host_id):
    776   """Generates infeed enqueue ops for per-host input_fn on a single host."""
    777   captured_infeed_queue = _CapturedObject()
    778 
    779   dataset_initializer = None
    780 
    781   with ops.device(device):
    782     user_context = tpu_context.TPUContext(
    783         internal_ctx=ctx, input_device=device, invocation_index=host_id)
    784     inputs = _Inputs.from_input_fn(input_fn(user_context))
    785 
    786     is_dataset = inputs.is_dataset
    787     if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
    788       if not is_dataset:
    789         raise TypeError(
    790             'For mode PREDICT, `input_fn` must return `Dataset` instead of '
    791             '`features` and `labels`.')
    792       if batch_axis is not None:
    793         raise TypeError('For mode PREDICT, batch_axis is not supported yet.')
    794       inputs = _InputsWithStoppingSignals(
    795           dataset=inputs.dataset,
    796           batch_size=ctx.batch_size_for_input_fn,
    797           add_padding=True)
    798 
    799     if is_dataset:
    800       dataset_initializer = inputs.dataset_initializer()
    801 
    802     tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)
    803 
    804   def enqueue_ops_fn():
    805     """A Fn returning the TPU infeed enqueue ops.
    806 
    807     By providing it as a Fn, it can be invoked inside the tf.while_loop such
    808     that the input pipeline for multiple iterations can be executed by one
    809     Session.run call.
    810 
    811     Returns:
    812       A list of enqueue ops, or a dict of enqueue ops and stopping signals.
    813     """
    814     with ops.device(device):
    815       num_of_replicas_per_host = ctx.num_of_replicas_per_host
    816       # Convert user input to features and labels.  If the user returns a
    817       # dataset, it is initialized and the features and labels extracted via
    818       # `dataset.iterator.get_next()`
    819       features, labels = inputs.features_and_labels()
    820       signals = inputs.signals()
    821 
    822       inputs_structure_recorder.validate_and_record_structure(features, labels)
    823       unsharded_tensor_list = (
    824           inputs_structure_recorder.flatten_features_and_labels(
    825               features, labels, signals))
    826 
    827       infeed_queue = tpu_feed.InfeedQueue(
    828           tuple_types=[t.dtype for t in unsharded_tensor_list],
    829           tuple_shapes=[t.shape for t in unsharded_tensor_list],
    830           shard_dimensions=batch_axis)
    831       captured_infeed_queue.capture(infeed_queue)
    832       infeed_queue.set_number_of_shards(num_of_replicas_per_host)
    833       per_host_enqueue_ops = (
    834           infeed_queue.split_inputs_and_generate_enqueue_ops(
    835               unsharded_tensor_list,
    836               placement_function=lambda x: device,
    837               tpu_ordinal_function=tpu_ordinal_function_impl))
    838       if signals is None:
    839         return per_host_enqueue_ops
    840       else:
    841         return {
    842             'ops': per_host_enqueue_ops,
    843             'signals': signals,
    844         }
    845 
    846   return enqueue_ops_fn, captured_infeed_queue, dataset_initializer
    847 
    848 
    849 def generate_per_host_v2_enqueue_ops_fn_for_host(
    850     ctx, input_fn, inputs_structure_recorder, device, host_id):
    851   """Generates infeed enqueue ops for per-host input_fn on a single host."""
    852   captured_infeed_queue = _CapturedObject()
    853   dataset_initializer = None
    854 
    855   with ops.device(device):
    856     user_context = tpu_context.TPUContext(
    857         internal_ctx=ctx, input_device=device, invocation_index=host_id)
    858     inputs = _Inputs.from_input_fn(input_fn(user_context))
    859 
    860     is_dataset = inputs.is_dataset
    861     if not is_dataset:
    862       raise TypeError('`input_fn` must return a `Dataset` for the PER_HOST_V2 '
    863                       'input pipeline configuration.')
    864 
    865     if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
    866       inputs = _InputsWithStoppingSignals(
    867           dataset=inputs.dataset,
    868           batch_size=ctx.batch_size_for_input_fn,
    869           add_padding=True,
    870           num_invocations_per_step=ctx.num_of_replicas_per_host)
    871 
    872     dataset_initializer = inputs.dataset_initializer()
    873     tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)
    874 
    875   def enqueue_ops_fn():
    876     """Generates the per_host enqueue ops."""
    877     control_deps = []
    878     per_host_sharded_inputs = []
    879     sparse_features_list = []
    880     num_replicas_per_host = ctx.num_of_replicas_per_host
    881     cached_signals = None
    882     with ops.device(device):
    883       if not inputs.is_dataset:
    884         raise TypeError('`input_fn` must return a `Dataset` for this mode.')
    885       for _ in range(num_replicas_per_host):
    886         # Use control dependencies to ensure a deterministic ordering.
    887         with ops.control_dependencies(control_deps):
    888           features, labels = inputs.features_and_labels()  # Calls get_next()
    889           signals = inputs.signals()
    890 
    891           # All the replicas share replica 0's stopping signal.
    892           # This avoids inconsistent state among different model replicas.
    893           if cached_signals:
    894             signals['stopping'] = cached_signals['stopping']
    895           else:
    896             cached_signals = signals
    897 
    898         features, labels, sparse_features = (
    899             _tpu_estimator_embedding.split_inputs(ctx, features, labels))
    900         sparse_features_list.append(sparse_features)
    901 
    902         inputs_structure_recorder.validate_and_record_structure(
    903             features, labels)
    904         flattened_inputs = (
    905             inputs_structure_recorder.flatten_features_and_labels(
    906                 features, labels, signals))
    907         control_deps.extend(flattened_inputs)
    908         per_host_sharded_inputs.append(flattened_inputs)
    909 
    910       if inputs_structure_recorder.flattened_input_dims:
    911         input_partition_dims = inputs_structure_recorder.flattened_input_dims
    912         if signals:
    913           input_partition_dims += [None] * len(signals)
    914         # pylint: disable=protected-access
    915         infeed_queue = tpu_feed._PartitionedInfeedQueue(
    916             number_of_tuple_elements=len(per_host_sharded_inputs[0]),
    917             host_id=host_id,
    918             input_partition_dims=input_partition_dims,
    919             device_assignment=ctx.device_assignment)
    920         per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
    921             per_host_sharded_inputs)
    922       else:
    923         infeed_queue = tpu_feed.InfeedQueue(
    924             number_of_tuple_elements=len(per_host_sharded_inputs[0]))
    925         per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
    926             per_host_sharded_inputs,
    927             tpu_ordinal_function=tpu_ordinal_function_impl)
    928       captured_infeed_queue.capture(infeed_queue)
    929 
    930     if ctx.embedding_config:
    931       per_host_enqueue_ops.extend(
    932           ctx.embedding_config.tpu_embedding.generate_enqueue_ops(
    933               sparse_features_list))
    934 
    935     if signals is None:
    936       return per_host_enqueue_ops
    937     else:
    938       return {
    939           'ops': per_host_enqueue_ops,
    940           'signals': signals,
    941       }
    942 
    943   return enqueue_ops_fn, captured_infeed_queue, dataset_initializer
    944 
    945 
    946 def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
    947                                       num_hosts):
    948   """Generates infeed enqueue ops for one input_fn on all the hosts."""
    949   captured_infeed_queue = _CapturedObject()
    950   dataset_initializer = None
    951   device_0 = ctx.tpu_host_placement_function(host_id=0)
    952   with ops.device(device_0):
    953     user_context = tpu_context.TPUContext(
    954         internal_ctx=ctx, input_device=device_0, invocation_index=0)
    955     inputs = _Inputs.from_input_fn(input_fn(user_context))
    956 
    957     is_dataset = inputs.is_dataset
    958     if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
    959       if not is_dataset:
    960         raise TypeError(
    961             'For mode PREDICT, `input_fn` must return `Dataset` instead of '
    962             '`features` and `labels`.')
    963 
    964       inputs = _InputsWithStoppingSignals(
    965           dataset=inputs.dataset,
    966           batch_size=ctx.batch_size_for_input_fn,
    967           add_padding=True)
    968 
    969     if is_dataset:
    970       dataset_initializer = inputs.dataset_initializer()
    971     num_replicas_per_host = ctx.num_of_replicas_per_host
    972 
    973   def tpu_ordinal_function_impl(replica_id):
    974     if ctx.device_assignment:
    975       return ctx.device_assignment.tpu_ordinal(replica=replica_id)
    976     else:
    977       return replica_id % num_replicas_per_host
    978 
    979   def device_function_impl(replica_id):
    980     return ctx.tpu_host_placement_function(replica_id=replica_id)
    981 
    982   def enqueue_ops_fn():
    983     """Generates enqueue ops for all the hosts."""
    984     broadcasted_inputs = []
    985     flattened_inputs = None  # Cache result from input_fn.
    986     signals = None
    987     num_replicas = ctx.num_replicas
    988     core_id = 0
    989     for host_id in xrange(num_hosts):
    990       with ops.device(ctx.tpu_host_placement_function(host_id=host_id)):
    991         for _ in xrange(ctx.num_of_replicas_per_host):
    992           # Note: input_fn is only called once at host 0 for the first replica.
    993           # The features and labels returned from that invocation are
    994           # broadcasted to other replicas (including the replicas on other
    995           # hosts).
    996           if flattened_inputs is None:
    997             features, labels = inputs.features_and_labels()  # Calls get_next()
    998             signals = inputs.signals()
    999 
   1000             inputs_structure_recorder.validate_and_record_structure(
   1001                 features, labels)
   1002             flattened_inputs = (
   1003                 inputs_structure_recorder.flatten_features_and_labels(
   1004                     features, labels, signals))
   1005             if (ctx.config.tpu_config.eval_training_input_configuration is
   1006                 tpu_config.InputPipelineConfig.SLICED):
   1007               input_slices = [
   1008                   array_ops.split(x, num_replicas) for x in flattened_inputs
   1009               ]
   1010           if (ctx.config.tpu_config.eval_training_input_configuration is
   1011               tpu_config.InputPipelineConfig.SLICED):
   1012             # For each core, slice out its portion of the flattened inputs.
   1013             broadcasted_inputs.append([x[core_id] for x in input_slices])
   1014             core_id += 1
   1015           else:
   1016             broadcasted_inputs.append(flattened_inputs)
   1017 
   1018     infeed_queue = tpu_feed.InfeedQueue(
   1019         number_of_tuple_elements=len(broadcasted_inputs[0]))
   1020     captured_infeed_queue.capture(infeed_queue)
   1021     enqueue_ops = infeed_queue.generate_enqueue_ops(
   1022         broadcasted_inputs,
   1023         tpu_ordinal_function=tpu_ordinal_function_impl,
   1024         placement_function=device_function_impl)
   1025 
   1026     if signals is None:
   1027       return enqueue_ops
   1028     else:
   1029       return {
   1030           'ops': enqueue_ops,
   1031           'signals': signals,
   1032       }
   1033 
   1034   return enqueue_ops_fn, captured_infeed_queue, dataset_initializer
   1035 
   1036 
   1037 class _InputPipeline(object):
   1038   """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue.
   1039 
   1040   `_InputPipeline` abstracts the per-core/per-host `input_fn` invocation from
   1041   call site.  To be precise, based on the configuration in
   1042   `_InternalTPUContext`,  it invokes `input_fn` for all cores (usually
   1043   multi-host TPU training) or for one host (usually for single-host TPU
   1044   evaluation), and sends all `features` and `labels` returned by `input_fn` to
   1045   TPU infeed. For per-core invocation, `features` and `labels` are piped to
   1046   infeed directly, one tuple for each core. For per-host invocation,  `features`
   1047   and `labels` are split at host (with respect to `batch_axis`) and piped to all
   1048   cores accordingly.
   1049 
   1050   In addition, flatten/unflatten are handled by `_InputPipeline` also.  Model
   1051   inputs returned by the `input_fn` can have one of the following forms:
   1052   1. features
   1053   2. (features, labels)
   1054   3. ((arbitrarily nested structure of features), labels)
   1055 
   1056   Internally, form 1 is reformed to `(features, None)` as features and labels
   1057   are passed separately to underlying methods. For TPU training, TPUEstimator
   1058   may expect multiple `features` and `labels` tuples one for each core.
   1059 
   1060   TPUEstimator allows various different structures for inputs (namely `features`
   1061   and `labels`). Both `features` and `labels` can be any nested structure
   1062   supported by TF nest (namely, dicts, tuples, namedtuples or any nested
   1063   structure of such containing Tensors). `labels` could be `None` as well.
   1064 
   1065   These are flattened before they are passed to the infeed/outfeed library,
   1066   as it expects flattened lists.
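
          For illustration, a hypothetical `input_fn` returning form 2 above (all names
          are placeholders; `params['batch_size']` is the batch size TPUEstimator
          injects into `params`):

            def my_input_fn(params):
              batch_size = params['batch_size']
              features = {'x': tf.random_uniform([batch_size, 8])}
              labels = tf.zeros([batch_size], dtype=tf.int32)
              return features, labels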
   1067   """
   1068 
   1069   class InputsStructureRecorder(object):
   1070     """Records the structure of the inputs."""
   1071 
   1072     def __init__(self, input_partition_dims=None):
   1073       # Holds the structure of inputs
   1074       self._feature_structure = {}
   1075       self._flattened_input_dims = None
   1076 
   1077       if input_partition_dims:
   1078         # This should have been validated in TPUConfig.
   1079         assert len(input_partition_dims) <= 2, 'must have 1 or 2 elements.'
   1080         if len(input_partition_dims) == 2:
   1081           self._feature_dims, self._label_dims = input_partition_dims
   1082         else:
   1083           self._feature_dims = input_partition_dims[0]
   1084           self._label_dims = None
   1085 
   1086         assert self._feature_dims is not None, ('input_partition_dims[0] must '
   1087                                                 'not be None')
   1088       else:
   1089         self._feature_dims = None
   1090         self._label_dims = None
   1091 
   1092       # Internal state.
   1093       self._initialized = False
   1094 
   1095     @property
   1096     def flattened_input_dims(self):
   1097       assert self._initialized, 'InputsStructureRecorder is not initialized.'
   1098       return self._flattened_input_dims
   1099 
   1100     def has_labels(self):
   1101       return 'labels' in self._feature_structure
   1102 
   1103     def _flatten_input_dims(self, feature_dims, feature_dims_names, label_dims,
   1104                             label_dims_names, label_names, has_labels):
   1105       """Flattens input dims in the same order as the flattened input tensors."""
   1106       flattened_input_dims = []
   1107       if feature_dims_names:
   1108         # We need a fixed ordering for matching the tensors in features.
   1109         flattened_input_dims.extend(
   1110             [feature_dims[name] for name in feature_dims_names])
   1111       else:
   1112         flattened_input_dims.append(feature_dims)
   1113 
   1114       if label_dims_names:
   1115         # We need a fixed ordering for matching the tensors in labels.
   1116         flattened_input_dims.extend(
   1117             [label_dims[name] for name in label_dims_names])
   1118       else:
   1119         if label_names:
   1120           num_tensors_in_label = len(label_names)
   1121         else:
   1122           num_tensors_in_label = int(has_labels)
   1123         # Setting `None` in input_partition_dims[1] will apply `None` to
   1124         # all the tensors in labels, regardless of internal structure.
   1125         flattened_input_dims.extend([label_dims] * num_tensors_in_label)
   1126 
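              # Illustration (hypothetical values, not from the original code): with
              # feature_dims = {'x': [1, 2], 'y': [1, 1]}, label_dims = None and a
              # single (non-dict) label tensor, this returns [[1, 2], [1, 1], None]:
              # the sorted feature keys' dims followed by one entry per label tensor.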
   1127       return flattened_input_dims
   1128 
   1129     def validate_and_record_structure(self, features, labels):
   1130       """Validates and records the structure of `features` and `labels`."""
   1131       # Extract structure.
   1132       has_labels = labels is not None
   1133       feature_names = _extract_key_names(features)
   1134       label_names = _extract_key_names(labels)
   1135 
   1136       if not self._initialized:
   1137         # Record structure.
   1138         self._initialized = True
   1139         if self._feature_dims is not None:
   1140           feature_dims_names = _extract_key_names(self._feature_dims)
   1141           if feature_dims_names != feature_names:
   1142             raise ValueError(
   1143                 'TPUConfig.input_partition_dims[0] mismatched feature'
   1144                 ' keys. Expected {}, got {}'.format(feature_names,
   1145                                                     feature_dims_names))
   1146 
   1147           label_dims_names = _extract_key_names(self._label_dims)
   1148           if self._label_dims is not None and label_dims_names != label_names:
   1149             raise ValueError(
   1150                 'TPUConfig.input_partition_dims[1] mismatched label'
   1151                 ' keys. Expected {}, got {}'.format(label_names,
   1152                                                     label_dims_names))
   1153 
   1154           self._flattened_input_dims = self._flatten_input_dims(
   1155               self._feature_dims, feature_dims_names, self._label_dims,
   1156               label_dims_names, label_names, has_labels)
   1157 
   1158     def flatten_features_and_labels(self, features, labels, signals=None):
   1159       """Flattens the `features` and `labels` to a single tensor list."""
   1160       self._feature_structure['features'] = features
   1161       if labels is not None:
   1162         self._feature_structure['labels'] = labels
   1163       if signals is not None:
   1164         self._feature_structure['signals'] = signals
   1165       return data_nest.flatten(self._feature_structure)
   1166 
   1167     def unflatten_features_and_labels(self, flattened_inputs):
   1168       """Restores the flattened inputs to original features and labels form.
   1169 
   1170       Args:
   1171         flattened_inputs: Flattened inputs for each shard.
   1172 
   1173       Returns:
   1174         A tuple of (`features`, `labels`), where `labels` could be None.
   1175         Each one, if present, should have identical structure (single tensor vs
   1176         dict) as the one returned by input_fn.
   1177 
   1178       Raises:
   1179         ValueError: If the number of expected tensors from `flattened_inputs`
   1180           does not match the recorded structure.
   1181       """
   1182 
   1183       unflattened_inputs = data_nest.pack_sequence_as(self._feature_structure,
   1184                                                       flattened_inputs)
   1185       return _Inputs(
   1186           unflattened_inputs['features'],
   1187           unflattened_inputs.get('labels'),
   1188           signals=unflattened_inputs.get('signals'))
   1189 
   1190   def __init__(self, input_fn, batch_axis, ctx):
   1191     """Constructor.
   1192 
   1193     Args:
   1194       input_fn: input fn for train or eval.
   1195       batch_axis: A python tuple of int values describing how each tensor
   1196         produced by the Estimator `input_fn` should be split across the TPU
   1197         compute shards.
   1198       ctx: A `_InternalTPUContext` instance with mode.
   1199 
   1200     Raises:
   1201       ValueError: If both `sharded_features` and `num_cores` are `None`.
   1202     """
   1203     self._inputs_structure_recorder = _InputPipeline.InputsStructureRecorder(
   1204         ctx.input_partition_dims)
   1205 
   1206     self._sharded_per_core = ctx.is_input_sharded_per_core()
   1207     self._input_fn = input_fn
   1208     self._infeed_queue = None
   1209     self._ctx = ctx
   1210     self._batch_axis = batch_axis
   1211 
   1212   def generate_infeed_enqueue_ops_and_dequeue_fn(self):
   1213     """Generates infeed enqueue ops and dequeue_fn."""
   1214     # When tf.while_loop is constructed, the body function, which invokes the
   1215     # `enqueue_fn` passed in, is called to build the graph, so the input_fn
   1216     # structure is recorded.
   1217     enqueue_ops, all_hooks, run_infeed_loop_on_coordinator = (
   1218         self._invoke_input_fn_and_record_structure())
   1219 
   1220     self._validate_input_pipeline()
   1221 
   1222     def dequeue_fn():
   1223       """dequeue_fn is used by TPU to retrieve the tensors."""
   1224       # In the model-parallel case, both the host-side and device-side
   1225       # computations must agree on the core on which infeed takes place. We
   1226       # choose to perform infeed on logical core 0 of each replica.
   1227       values = self._infeed_queue.generate_dequeue_op(tpu_device=0)
   1228       # The unflatten process uses the structure information recorded above.
   1229       return self._inputs_structure_recorder.unflatten_features_and_labels(
   1230           values)
   1231 
   1232     return (enqueue_ops, dequeue_fn, all_hooks, run_infeed_loop_on_coordinator)
   1233 
   1234   def _invoke_input_fn_and_record_structure(self):
   1235     """Deploys the input pipeline and records the input structure."""
   1236     enqueue_ops = []
   1237     infeed_queues = []
   1238     all_dataset_initializers = []
   1239     num_hosts = self._ctx.num_hosts
   1240     tpu_host_placement_fn = self._ctx.tpu_host_placement_function
   1241 
   1242     run_infeed_loop_on_coordinator = True
   1243 
   1244     if self._sharded_per_core:
   1245       # Per-core input pipeline deployment.
   1246       # Invoke the input pipeline for each core and place it on the
   1247       # corresponding host.
   1248       for host_id in range(num_hosts):
   1249         host_device = tpu_host_placement_fn(host_id=host_id)
   1250         with ops.device(host_device):
   1251           with ops.name_scope('input_pipeline_task%d' % (host_id)):
   1252             enqueue_ops_fn, captured_infeed_queue = (
   1253                 generate_per_core_enqueue_ops_fn_for_host(
   1254                     self._ctx, self._input_fn, self._inputs_structure_recorder,
   1255                     host_device, host_id))
   1256 
   1257             if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
   1258               run_infeed_loop_on_coordinator = False
   1259               enqueue_ops.append(
   1260                   _wrap_computation_in_while_loop(
   1261                       device=host_device, op_fn=enqueue_ops_fn))
   1262             else:
   1263               enqueue_ops.append(enqueue_ops_fn())
   1264             # Infeed_queue_getter must be called after enqueue_ops_fn is called.
   1265             infeed_queues.append(captured_infeed_queue.get())
   1266 
   1267     elif self._ctx.is_input_broadcast_with_iterators():
   1268       # Only call input_fn on host 0.
   1269       host_device = tpu_host_placement_fn(host_id=0)
   1270       enqueue_ops_fn, captured_infeed_queue, dataset_initializer = (
   1271           generate_broadcast_enqueue_ops_fn(self._ctx, self._input_fn,
   1272                                             self._inputs_structure_recorder,
   1273                                             num_hosts))
   1274       if dataset_initializer:
   1275         all_dataset_initializers.append(dataset_initializer)
   1276         run_infeed_loop_on_coordinator = False
   1277         wrap_fn = (
   1278             _wrap_computation_in_while_loop
   1279             if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else
   1280             _wrap_computation_in_while_loop_with_stopping_signals)
   1281         enqueue_ops.append(wrap_fn(device=host_device, op_fn=enqueue_ops_fn))
   1282       else:
   1283         enqueue_ops.append(enqueue_ops_fn())
   1284       infeed_queues.append(captured_infeed_queue.get())
   1285     else:
   1286       for host_id in range(num_hosts):
   1287         host_device = tpu_host_placement_fn(host_id=host_id)
   1288         with ops.device(host_device):
   1289           with ops.name_scope('input_pipeline_task%d' % (host_id)):
   1290             if self._ctx.is_input_per_host_with_iterators():
   1291               enqueue_ops_fn, captured_infeed_queue, dataset_initializer = (
   1292                   generate_per_host_v2_enqueue_ops_fn_for_host(
   1293                       self._ctx, self._input_fn,
   1294                       self._inputs_structure_recorder, host_device, host_id))
   1295             else:
   1296               enqueue_ops_fn, captured_infeed_queue, dataset_initializer = (
   1297                   generate_per_host_enqueue_ops_fn_for_host(
   1298                       self._ctx, self._input_fn,
   1299                       self._inputs_structure_recorder, self._batch_axis,
   1300                       host_device, host_id))
   1301 
   1302             # NOTE(xiejw): We dispatch here based on the return type of the
   1303             # users `input_fn`.
   1304             #
   1305             # 1. If input_fn returns a Dataset instance, we initialize the
   1306             # iterator outside of tf.while_loop, and call the iterator.get_next
   1307             # inside tf.while_loop.  This should be always safe.
   1308             #
   1309             # 2. If input_fn returns (features, labels), it is too late to wrap
   1310             # them inside tf.while_loop, as resource initialization cannot be
   1311             # handled properly in TF control flow. In this case, we use a
   1312             # Python loop to enqueue the data into the TPU system.  This may
   1313             # be slow compared to the previous case.
   1314             if dataset_initializer:
   1315               all_dataset_initializers.append(dataset_initializer)
   1316               run_infeed_loop_on_coordinator = False
   1317               wrap_fn = (
   1318                   _wrap_computation_in_while_loop
   1319                   if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else
   1320                   _wrap_computation_in_while_loop_with_stopping_signals)
   1321               enqueue_ops.append(
   1322                   wrap_fn(device=host_device, op_fn=enqueue_ops_fn))
   1323             else:
   1324               enqueue_ops.append(enqueue_ops_fn())
   1325             infeed_queues.append(captured_infeed_queue.get())
   1326     # infeed_queue is used to generate dequeue ops. The only things it uses
   1327     # for dequeue are the dtypes and shapes, so any one of them can be used.
   1328     # Here, grab the first one.
   1329     self._infeed_queue = infeed_queues[0]
   1330     return enqueue_ops, [
   1331         util_lib.MultiHostDatasetInitializerHook(all_dataset_initializers)
   1332     ], run_infeed_loop_on_coordinator
   1333 
   1334   def _validate_input_pipeline(self):
   1335     """Validates the input pipeline.
   1336 
   1337     Performs some sanity checks and logs user-friendly information. Ideally we
   1338     would error out to give users a better error message, but if
   1339     _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break
   1340     user code, so we only log a warning.
   1341 
   1342     Raises:
   1343       RuntimeError: If the validation failed.
   1344     """
   1345     if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS):
   1346       err_msg = ('Input pipeline contains one or more QueueRunners. '
   1347                  'It could be slow and not scalable. Please consider '
   1348                  'converting your input pipeline to use `tf.data` instead (see '
   1349                  'https://www.tensorflow.org/guide/datasets for '
   1350                  'instructions).')
   1351       if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
   1352         raise RuntimeError(err_msg)
   1353       else:
   1354         logging.warn(err_msg)
   1355 
   1356 
   1357 def call_computation(computation,
   1358                      experimental_exported_model_uses_all_cores=True):
   1359   """Call computation.
   1360 
   1361   If `experimental_exported_model_uses_all_cores` is `True`, this function
   1362   will round-robin the computation among all TPU cores visible to the host;
   1363   otherwise, it will run the computation on a single core (the default for
   1364   TPU inference).
   1366 
   1367   Args:
   1368     computation: A Python function that takes no inputs and builds computation
   1369       graph. If `computation` returns m outputs, this function will return a
   1370       list of m Tensors.
   1371     experimental_exported_model_uses_all_cores: Whether to round-robin among all
   1372       cores visible to the host, or to use a single core.
   1373 
   1374   Returns:
   1375     A list of output tensors.
   1376   """
   1377   if experimental_exported_model_uses_all_cores:
   1378     # Using `TPUPartitionedCall` makes it possible to target a different
   1379     # TPU core with every `Session.run()` call. Note that the entire inference
   1380     # graph executes on a single core, and that invocations of this graph
   1381     # will round-robin among the cores attached to a host.
   1382     @function.Defun(capture_resource_var_by_value=False)
   1383     def tpu_subgraph():
   1384       return computation()
   1385 
   1386     return tpu_functional.TPUPartitionedCall(
   1387         args=tpu_subgraph.captured_inputs,
   1388         device_ordinal=tpu_ops.tpu_ordinal_selector(),
   1389         Tout=[o.type for o in tpu_subgraph.definition.signature.output_arg],
   1390         f=tpu_subgraph)
   1391   else:
   1392     return computation()
   1393 
   1394 
   1395 class _ModelFnWrapper(object):
   1396   """A `model_fn` wrapper.
   1397 
   1398   This makes calling model_fn on CPU and TPU easier and more consistent, and
   1399   performs the checks and mutations required by TPU training and evaluation.
   1400 
   1401   In addition, this wrapper manages converting the `model_fn` to a single TPU
   1402   train and eval step.
   1403   """
   1404 
   1405   def __init__(self, model_fn, config, params, ctx):
   1406     self._model_fn = model_fn
   1407     self._config = config
   1408     self._params = params
   1409     self._ctx = ctx
   1410 
   1411   def call_without_tpu(self, features, labels, is_export_mode):
   1412     return self._call_model_fn(features, labels, is_export_mode=is_export_mode)
   1413 
   1414   def _add_embedding_features(self, features, hook_dummy_table_variables):
   1415     """Adds embedding features; optionally adds a hook to intercept gradients."""
   1416     if self._ctx.embedding_config:
   1417       tpu_embedding_ = self._ctx.embedding_config.tpu_embedding
   1418       embedding_activations = tpu_embedding_.get_activations()
   1419       if hook_dummy_table_variables:
   1420         new_embedding_activations = (
   1421             tpu_embedding_gradient.hook_dummy_table_variables_to_activations(
   1422                 tpu_embedding_, embedding_activations,
   1423                 self._ctx.embedding_config.dummy_table_variables))
   1424         features.update(new_embedding_activations)
   1425       else:
   1426         features.update(embedding_activations)
   1427 
   1428   def convert_to_single_tpu_train_step(self, dequeue_fn):
   1429     """Converts the user-provided `model_fn` into a single train step on TPU.
   1430 
   1431     The user-provided `model_fn` takes the input tuple (features, labels)
   1432     and produces an EstimatorSpec with train_op and loss for train `mode`.
   1433     This usually represents a single train computation on CPU.
   1434 
   1435     For TPU training, a train (computation) step is first wrapped in a
   1436     tf.while_loop control flow to repeat many times and then replicated to
   1437     all TPU shards. In addition, the input should be taken from the TPU
   1438     infeed rather than from the input pipeline (input_fn) directly. To fit
   1439     the TPU loop-and-replicate pattern, the original train computation is
   1440     reformed into the returned `train_step`.
   1441 
   1442     Args:
   1443       dequeue_fn: The function to retrieve inputs, features and labels, from TPU
   1444         infeed dequeue channel.
   1445 
   1446     Returns:
   1447       A tuple of train_fn, host_calls, captured scaffold_fn, and captured
   1448       training hooks. The train_fn represents the train step for TPU.
   1449     """
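        # Conceptual sketch only (names and wiring here are illustrative, not the
        # exact code TPUEstimator generates): the returned train_step is later
        # repeated on the device, roughly like
        #
        #   def tpu_train_loop():
        #     return training_loop.repeat(iterations_per_loop, train_step,
        #                                 [_INITIAL_LOSS])
        #
        # so it must consume inputs via dequeue_fn rather than closing over the
        # outputs of input_fn.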
   1450 
   1451     host_call = _OutfeedHostCall(self._ctx)
   1452     captured_scaffold_fn = _CapturedObject()
   1453     captured_training_hooks = _CapturedObject()
   1454 
   1455     def train_step(loss):
   1456       """Training step function for use inside a while loop."""
   1457       del loss  # unused; required in function signature.
   1458       inputs = dequeue_fn()
   1459       features, labels = inputs.features_and_labels()
   1460       self._add_embedding_features(features, True)
   1461 
   1462       estimator_spec = self._verify_estimator_spec(
   1463           self._call_model_fn(features, labels))
   1464       loss, train_op = estimator_spec.loss, estimator_spec.train_op
   1465 
   1466       if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
   1467         captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
   1468       else:
   1469         captured_scaffold_fn.capture(None)
   1470 
   1471       captured_training_hooks.capture(estimator_spec.training_hooks)
   1472 
   1473       if self._ctx.embedding_config is None:
   1474         apply_sparse_grads = []
   1475       else:
   1476         tpu_embedding_ = self._ctx.embedding_config.tpu_embedding
   1477         gradients = (
   1478             tpu_embedding_gradient.get_gradients_through_dummy_table_variables(
   1479                 tpu_embedding_)
   1480         )
   1481         apply_sparse_grads = [
   1482             tpu_embedding_.generate_send_gradients_op(gradients)
   1483         ]
   1484 
   1485       # We must run train_op to update the variables prior to running the
   1486       # outfeed.
   1487       with ops.control_dependencies([train_op] + apply_sparse_grads):
   1488         host_call_outfeed_ops = []
   1489         if (isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)  # pylint: disable=protected-access
   1490             and estimator_spec.host_call is not None):
   1491           host_call.record({'host_call': estimator_spec.host_call})
   1492           host_call_outfeed_ops = host_call.create_enqueue_op()
   1493         with ops.control_dependencies(host_call_outfeed_ops):
   1494           return array_ops.identity(loss)
   1495 
   1496     return (train_step, host_call, captured_scaffold_fn,
   1497             captured_training_hooks)
   1498 
   1499   def convert_to_single_tpu_eval_step(self, dequeue_fn):
   1500     """Converts the user-provided `model_fn` into a single eval step on TPU.
   1501 
   1502     Similar to training, the user-provided `model_fn` takes the input tuple
   1503     (features, labels) and produces a TPUEstimatorSpec with eval_metrics for
   1504     eval `mode`. This usually represents a single evaluation computation on CPU.
   1505 
   1506     For TPU evaluation, an eval (computation) step is first wrapped in a
   1507     tf.while_loop control flow to repeat many times and then replicated to all
   1508     TPU shards. In addition, the input and output are slightly different. The
   1509     input, features and labels, should be taken from the TPU infeed rather than
   1510     from the input pipeline (input_fn) directly. The output is managed in two
   1511     stages. First, the model outputs, as the result of the evaluation
   1512     computation (usually model logits), are transferred from the TPU system to
   1513     the CPU. Then, all model outputs are concatenated on CPU and sent to the
   1514     metric_fn for metrics computation. To fit the TPU evaluation pattern, the
   1515     original eval computation is reformed into the returned `eval_step`.
   1516 
   1517     Args:
   1518       dequeue_fn: The function to retrieve inputs, features and labels, from TPU
   1519         infeed dequeue channel.
   1520 
   1521     Returns:
   1522       A tuple of eval_fn, host_calls, captured scaffold_fn, and captured eval
   1523       hooks. The eval_fn represents the eval step for TPU.
   1524     """
   1525     host_calls = _OutfeedHostCall(self._ctx)
   1526     captured_scaffold_fn = _CapturedObject()
   1527     captured_eval_hooks = _CapturedObject()
   1528 
   1529     def eval_step(total_loss):
   1530       """Evaluation step function for use inside a while loop."""
   1531       inputs = dequeue_fn()
   1532       features, labels = inputs.features_and_labels()
   1533       self._add_embedding_features(features, False)
   1534 
   1535       tpu_estimator_spec = self._call_model_fn(features, labels)
   1536       if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
   1537         raise RuntimeError(
   1538             'estimator_spec used by TPU evaluation must have type'
   1539             '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))
   1540 
   1541       loss = tpu_estimator_spec.loss
   1542       captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)
   1543       captured_eval_hooks.capture(tpu_estimator_spec.evaluation_hooks)
   1544 
   1545       to_record = {}
   1546       if tpu_estimator_spec.eval_metrics:
   1547         to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics
   1548       if tpu_estimator_spec.host_call is not None:
   1549         # We assume that evaluate won't update global step, so we don't wrap
   1550         # this host_call.
   1551         to_record['host_call'] = tpu_estimator_spec.host_call
   1552       host_calls.record(to_record)
   1553 
   1554       with ops.control_dependencies(host_calls.create_enqueue_op()):
   1555         return math_ops.add(total_loss, loss)
   1556 
   1557     return eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks
   1558 
   1559   def convert_to_single_tpu_predict_step(self, dequeue_fn):
   1560     """Converts the user-provided `model_fn` into a single predict step on TPU.
   1561 
   1562     Args:
   1563       dequeue_fn: The function to retrieve inputs, features and labels, from TPU
   1564         infeed dequeue channel.
   1565 
   1566     Returns:
   1567       A tuple of predict_fn, host_calls, captured scaffold_fn, and captured
   1568       predict hooks. The predict_fn represents the predict step for TPU.
   1569     """
   1570     host_calls = _OutfeedHostCall(self._ctx)
   1571     captured_scaffold_fn = _CapturedObject()
   1572     captured_predict_hooks = _CapturedObject()
   1573 
   1574     def predict_step(unused_scalar_stopping_signal):
   1575       """Prediction step function for use inside a while loop."""
   1576       inputs = dequeue_fn()
   1577       features, labels = inputs.features_and_labels()
   1578       stopping_signals = inputs.signals()
   1579 
   1580       assert stopping_signals is not None, (
   1581           'Internal Error: `signals` is missing.')
   1582 
   1583       tpu_estimator_spec = self._call_model_fn(
   1584           features, labels, is_export_mode=False)
   1585       if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
   1586         raise RuntimeError(
   1587             'estimator_spec used by TPU prediction must have type'
   1588             '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))
   1589 
   1590       self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions)
   1591 
   1592       captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)
   1593       captured_predict_hooks.capture(tpu_estimator_spec.prediction_hooks)
   1594       to_record = {}
   1595       identity_fn = lambda **kwargs: kwargs
   1596       to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions]
   1597       to_record['signals'] = [identity_fn, stopping_signals]
   1598       if tpu_estimator_spec.host_call is not None:
   1599         to_record['host_call'] = tpu_estimator_spec.host_call
   1600       host_calls.record(to_record)
   1601 
   1602       with ops.control_dependencies(host_calls.create_enqueue_op()):
   1603         return _StopSignals.as_scalar_stopping_signal(stopping_signals)
   1604 
   1605     return (predict_step, host_calls, captured_scaffold_fn,
   1606             captured_predict_hooks)
   1607 
   1608   def _verify_tpu_spec_predictions(self, predictions):
   1609     """Validates TPUEstimatorSpec.predictions dict."""
   1610     # TODO(xiejw): Add validation for the prediction dictionary.
   1611     # TODO(xiejw): Add support for a single tensor as predictions.
   1612     if not isinstance(predictions, dict):
   1613       raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.')
   1614 
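        # Illustrative example (hypothetical values): a prediction tensor of shape
        # [16, 10] passes the check below, while one with a dynamic leading
        # (batch) dimension, e.g. [None, 10], does not.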
   1615     for (key, tensor) in predictions.items():
   1616       if tensor.shape.dims[0].value is None:
   1617         raise ValueError(
   1618             'The tensor with key ({}) in TPUEstimatorSpec.predictions has '
   1619             'dynamic shape (should be static). Tensor: {}'.format(key, tensor))
   1620     return predictions
   1621 
   1622   def _validate_model_features_and_labels(self, features, labels,
   1623                                           is_export_mode):
   1624     """Validates that the features and labels for the model function are valid.
   1625 
   1626     A valid features/labels object is the one with:
   1627     - Type: A tensor or any nested structure of tensors supported by TF nest,
   1628         namely nested dictionary, tuple, namedtuple, or sequence of tensors.
   1629     - Static shape if is_export_mode is False.
   1630 
   1631     Args:
   1632       features: the features that would be input to the model function.
   1633       labels: the labels that would be input to the model function.
   1634       is_export_mode: boolean value specifying if in export mode.
   1635 
   1636     Raises:
   1637       TypeError: If features/labels are not of the correct type.
   1638       ValueError: If features/labels have dynamic shape.
   1639     """
   1640 
   1641     def validate(obj, obj_name):
   1642       """Helper validate function."""
   1643       if is_export_mode or self._ctx.is_running_on_cpu(is_export_mode):
   1644         return
   1645       if isinstance(obj, ops.Tensor):
   1646         if not obj.get_shape().is_fully_defined():
   1647           raise ValueError(
   1648               'The {} to the model returned by input_fn must have static shape.'
   1649               ' Tensor: {}'.format(obj_name, obj))
   1650       else:
   1651         for tensor in data_nest.flatten(obj):
   1652           if not tensor.get_shape().is_fully_defined():
   1653             raise ValueError(
   1654                 ('The {} to the model returned by input_fn must have static '
   1655                  'shape. Tensor: {}').format(obj_name, tensor))
   1656 
   1657     validate(features, 'features')
   1658     if labels is not None:
   1659       validate(labels, 'labels')
   1660 
   1661   def _call_model_fn(self, features, labels, is_export_mode=False):
   1662     """Calls the model_fn with required parameters."""
   1663     self._validate_model_features_and_labels(features, labels, is_export_mode)
   1664     model_fn_args = function_utils.fn_args(self._model_fn)
   1665     kwargs = {}
   1666 
   1667     # Make deep copies of `config` and `params` in case the user mutates them.
   1668     config = copy.deepcopy(self._config)
   1669     params = copy.deepcopy(self._params)
   1670 
   1671     if 'labels' in model_fn_args:
   1672       kwargs['labels'] = labels
   1673     elif labels is not None:
   1674       raise ValueError(
   1675           'model_fn does not take labels, but input_fn returns labels.')
   1676     if 'mode' in model_fn_args:
   1677       kwargs['mode'] = self._ctx.mode
   1678     if 'config' in model_fn_args:
   1679       kwargs['config'] = config
   1680     if 'params' in model_fn_args:
   1681       kwargs['params'] = params
   1682 
   1683     if 'params' not in model_fn_args:
   1684       raise ValueError('model_fn ({}) does not include params argument, '
   1685                        'required by TPUEstimator to pass batch size as '
   1686                        'params[\'batch_size\']'.format(self._model_fn))
   1687 
   1688     if is_export_mode:
   1689       batch_size_for_model_fn = None
   1690     else:
   1691       batch_size_for_model_fn = self._ctx.batch_size_for_model_fn
   1692 
   1693     if batch_size_for_model_fn is not None:
   1694       _add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn)
   1695 
   1696     running_on_cpu = self._ctx.is_running_on_cpu(is_export_mode)
   1697     # In export mode, params['use_tpu'] has already been set based on mode
   1698     # (i.e. True for _REWRITE_FOR_INFERENCE_MODE, False otherwise).
   1699     if not is_export_mode:
   1700       _add_item_to_params(params, _USE_TPU_KEY, not running_on_cpu)
   1701 
   1702     if not running_on_cpu:
   1703       user_context = tpu_context.TPUContext(
   1704           internal_ctx=self._ctx, call_from_input_fn=False)
   1705       _add_item_to_params(params, _CTX_KEY, user_context)
   1706 
   1707     estimator_spec = self._model_fn(features=features, **kwargs)
   1708     if (running_on_cpu and
   1709         isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)):  # pylint: disable=protected-access
   1710       # The estimator_spec will be passed to `Estimator` directly, which expects
   1711       # type `EstimatorSpec`.
   1712       return estimator_spec.as_estimator_spec()
   1713     else:
   1714       return estimator_spec
   1715 
   1716   def _verify_estimator_spec(self, estimator_spec):
   1717     """Validates the estimator_spec."""
   1718     if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
   1719       return estimator_spec
   1720 
   1721     err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.'
   1722     if estimator_spec.training_chief_hooks:
   1723       raise ValueError(
   1724           err_msg.format('training_chief_hooks') + ' If you want' +
   1725           ' to pass training hooks, please pass via training_hooks.')
   1726 
   1727     if estimator_spec.scaffold:
   1728       logging.warning('EstimatorSpec.Scaffold is ignored by TPU train/eval. '
   1729                       'Please use TPUEstimatorSpec.')
   1730     return estimator_spec
   1731 
   1732 
   1733 class _OutfeedHostCall(object):
   1734   """Support for `eval_metrics` and `host_call` in TPUEstimatorSpec."""
   1735 
   1736   def __init__(self, ctx):
   1737     self._ctx = ctx
   1738     self._names = []
   1739     # All of these are dictionaries of lists keyed on the name.
   1740     self._host_fns = {}
   1741     self._tensor_keys = collections.defaultdict(list)
   1742     self._tensors = collections.defaultdict(list)
   1743     self._tensor_dtypes = collections.defaultdict(list)
   1744     self._tensor_shapes = collections.defaultdict(list)
   1745 
   1746   @staticmethod
   1747   def validate(host_calls):
   1748     """Validates the `eval_metrics` and `host_call` in `TPUEstimatorSpec`."""
   1749 
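        # Illustrative example (hypothetical names): a valid entry looks like
        #   {'eval_metrics': (my_metric_fn, [labels, logits])}
        # where my_metric_fn accepts two positional tensor arguments (or uses
        # *args) to match the length of the tensor list.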
   1750     for name, host_call in host_calls.items():
   1751       if not isinstance(host_call, (tuple, list)):
   1752         raise ValueError('{} should be tuple or list'.format(name))
   1753       if len(host_call) != 2:
   1754         raise ValueError('{} should have two elements.'.format(name))
   1755       if not callable(host_call[0]):
   1756         raise TypeError('{}[0] should be callable.'.format(name))
   1757       if not isinstance(host_call[1], (tuple, list, dict)):
   1758         raise ValueError('{}[1] should be tuple or list, or dict.'.format(name))
   1759 
   1760       if isinstance(host_call[1], (tuple, list)):
   1761         fullargspec = tf_inspect.getfullargspec(host_call[0])
   1762         fn_args = function_utils.fn_args(host_call[0])
   1763         # wrapped_hostcall_with_global_step uses varargs, so we allow that.
   1764         if fullargspec.varargs is None and len(host_call[1]) != len(fn_args):
   1765           raise RuntimeError(
   1766               'In TPUEstimatorSpec.{}, length of tensors {} does not match '
   1767               'method args of the function, which takes {}.'.format(
   1768                   name, len(host_call[1]), len(fn_args)))
   1769 
   1770   @staticmethod
   1771   def create_cpu_hostcall(host_calls):
   1772     """Runs the host_call on CPU instead of TPU when use_tpu=False."""
   1773 
   1774     _OutfeedHostCall.validate(host_calls)
   1775     ret = {}
   1776     for name, host_call in host_calls.items():
   1777       host_fn, tensors = host_call
   1778       if isinstance(tensors, (tuple, list)):
   1779         ret[name] = host_fn(*tensors)
   1780       else:
   1781         # Must be dict.
   1782         try:
   1783           ret[name] = host_fn(**tensors)
   1784         except TypeError as e:
   1785           logging.warning(
   1786               'Exception while calling %s: %s. It is likely the tensors '
   1787               '(%s[1]) do not match the '
   1788               'function\'s arguments', name, e, name)
   1789           raise
   1790     return ret
   1791 
   1792   def record(self, host_calls):
   1793     """Records the host_call structure."""
   1794 
   1795     for name, host_call in host_calls.items():
   1796       host_fn, tensor_list_or_dict = host_call
   1797       self._names.append(name)
   1798       self._host_fns[name] = host_fn
   1799 
   1800       if isinstance(tensor_list_or_dict, dict):
   1801         for (key, tensor) in six.iteritems(tensor_list_or_dict):
   1802           self._tensor_keys[name].append(key)
   1803           self._tensors[name].append(tensor)
   1804           self._tensor_dtypes[name].append(tensor.dtype)
   1805           self._tensor_shapes[name].append(tensor.shape)
   1806       else:
   1807         # List or tuple.
   1808         self._tensor_keys[name] = None
   1809         for tensor in tensor_list_or_dict:
   1810           self._tensors[name].append(tensor)
   1811           self._tensor_dtypes[name].append(tensor.dtype)
   1812           self._tensor_shapes[name].append(tensor.shape)
   1813 
   1814   def create_enqueue_op(self):
   1815     """Create the op to enqueue the recorded host_calls.
   1816 
   1817     Returns:
   1818       A list of enqueue ops, which is empty if there are no host calls.
   1819     """
   1820     if not self._names:
   1821       return []
   1822 
   1823     tensors = []
   1824     # TODO(jhseu): Consider deduping tensors.
   1825     for name in self._names:
   1826       tensors.extend(self._tensors[name])
   1827 
   1828     with ops.device(tpu.core(0)):
   1829       return [tpu_ops.outfeed_enqueue_tuple(tensors)]
   1830 
   1831   def create_tpu_hostcall(self):
   1832     """Sends the tensors through outfeed and runs the host_fn on CPU.
   1833 
   1834     The tensors are concatenated along dimension 0 to form a global tensor
   1835     across all shards. The concatenated tensors are passed to the host_fn,
   1836     which is executed on the first host.
   1837 
   1838     Returns:
   1839       A dictionary mapping name to the return type of the host_call by that
   1840       name.
   1841 
   1842     Raises:
   1843       RuntimeError: If outfeed tensor is scalar.
   1844     """
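        # Illustrative example (hypothetical shapes): with 8 replicas each
        # outfeeding a [16, 10] logits tensor, the host_fn for that entry receives
        # a single [128, 10] tensor concatenated along axis 0 (except in BROADCAST
        # input mode, where only the first replica's copy is used).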
   1845     if not self._names:
   1846       return {}
   1847 
   1848     ret = {}
   1849     # For each i, dequeue_ops[i] is a list containing the tensors from all
   1850     # shards. This list is concatenated later.
   1851     dequeue_ops = []
   1852     tensor_dtypes = []
   1853     tensor_shapes = []
   1854     for name in self._names:
   1855       for _ in self._tensors[name]:
   1856         dequeue_ops.append([])
   1857       for dtype in self._tensor_dtypes[name]:
   1858         tensor_dtypes.append(dtype)
   1859       for shape in self._tensor_shapes[name]:
   1860         tensor_shapes.append(shape)
   1861 
   1862     # Outfeed ops execute on each replica's first logical core. Note: we must
   1863     # constrain it such that we have at most one outfeed dequeue and enqueue
   1864     # per replica.
   1865     for i in xrange(self._ctx.num_replicas):
   1866       host_device, ordinal_id = self._ctx.device_for_replica(i)
   1867       with ops.device(host_device):
   1868         outfeed_tensors = tpu_ops.outfeed_dequeue_tuple(
   1869             dtypes=tensor_dtypes,
   1870             shapes=tensor_shapes,
   1871             device_ordinal=ordinal_id)
   1872         for j, item in enumerate(outfeed_tensors):
   1873           dequeue_ops[j].append(item)
   1874 
   1875     # Deconstruct dequeue ops.
   1876     flat_dequeue_ops = []
   1877     for l in dequeue_ops:
   1878       flat_dequeue_ops.extend(l)
   1879 
   1880     dequeue_ops_by_name = {}
   1881     pos = 0
   1882     for name in self._names:
   1883       dequeue_ops_by_name[name] = dequeue_ops[pos:pos +
   1884                                               len(self._tensors[name])]
   1885       pos += len(self._tensors[name])
   1886 
   1887     def _call_host_fn(fn, *args, **kw):
   1888       context = CatchInvalidHostcallFunctions()
   1889       context.Enter()
   1890       result = fn(*args, **kw)
   1891       context.Exit()
   1892       context.ExitResult(result)
   1893       return result
   1894 
   1895     # It is assumed evaluation always happens on a single-host TPU system, so
   1896     # place all ops on the TPU host if possible.
   1897     #
   1898     # TODO(jhseu): Evaluate whether this is right for summaries.
   1899     with ops.device(self._ctx.tpu_host_placement_function(replica_id=0)):
   1900       for name in self._names:
   1901         dequeue_ops = dequeue_ops_by_name[name]
   1902         for i, item in enumerate(dequeue_ops):
   1903           if dequeue_ops[i][0].shape.ndims == 0:
   1904             raise RuntimeError(
   1905                 'All tensors outfed from TPU should preserve batch size '
   1906                 'dimension, but got scalar {}'.format(dequeue_ops[i][0]))
   1907           # TODO(xiejw): Make the specification of the outfeed combination
   1908           # function more explicit and well-documented.  We may want to give the
   1909           # user the option of concatenating along any axis.
   1910           if (self._ctx.config.tpu_config.per_host_input_for_training is
   1911               tpu_config.InputPipelineConfig.BROADCAST):
   1912             # If the infeed is in BROADCAST mode (each core receiving the same
   1913             # input), then we assume that the cores also produce identical
   1914             # copies of the same output, and we simply take the output from
   1915             # the first core.  This mode is used by Mesh-TensorFlow.
   1916             with ops.control_dependencies(dequeue_ops[i]):
   1917               dequeue_ops[i] = array_ops.identity(dequeue_ops[i][0])
   1918           else:
   1919             # Assume that the input has been batch-split and that axis 0 of the
   1920             # output tensors represents the batch size.  Concatenate along
   1921             # the axis 0 to re-combine the batch.
   1922             dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0)
   1923 
   1924         if self._tensor_keys[name] is not None:
   1925           # The user-provided eval_metrics[1] is a dict.
   1926           dequeue_ops = dict(zip(self._tensor_keys[name], dequeue_ops))
   1927           try:
   1928             ret[name] = _call_host_fn(self._host_fns[name], **dequeue_ops)
   1929           except TypeError as e:
   1930             logging.warning(
   1931                 'Exception while calling %s: %s. It is likely the tensors '
   1932                 '(%s[1]) do not match the '
   1933                 'function\'s arguments', name, e, name)
   1934             raise
   1935         else:
   1936           ret[name] = _call_host_fn(self._host_fns[name], *dequeue_ops)
   1937 
   1938     # force all dequeue operations to be run if not consumed by the host calls
   1939     ret['__force_dequeue'] = control_flow_ops.group(*flat_dequeue_ops)
   1940     return ret
   1941 
   1942 
   1943 class _OutfeedHostCallHook(session_run_hook.SessionRunHook):
   1944   """Hook to run host calls when use_tpu=False."""
   1945 
   1946   def __init__(self, tensors):
   1947     self._tensors = tensors
   1948 
   1949   def begin(self):
   1950     # We duplicate this code from the TPUInfeedOutfeedSessionHook rather than
   1951     # create a separate hook to guarantee execution order, because summaries
   1952     # need to be initialized before the outfeed thread starts.
   1953     # TODO(jhseu): Make a wrapper hook instead?
   1954     self._init_ops = contrib_summary.summary_writer_initializer_op()
   1955     # Get all the writer resources from the initializer, so we know what to
   1956     # flush.
   1957     self._finalize_ops = []
   1958     for op in self._init_ops:
   1959       self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0]))
   1960 
   1961   def after_create_session(self, session, coord):
   1962     session.run(self._init_ops)
   1963 
   1964   def before_run(self, run_context):
   1965     return basic_session_run_hooks.SessionRunArgs(self._tensors)
   1966 
   1967   def end(self, session):
   1968     session.run(self._finalize_ops)
   1969 
   1970 
   1971 class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook):
   1972   """Calculate and report global_step/sec and examples/sec during runtime."""
   1973 
   1974   def __init__(self,
   1975                batch_size,
   1976                every_n_steps=100,
   1977                every_n_secs=None,
   1978                output_dir=None,
   1979                summary_writer=None):
   1980     self._batch_size = batch_size
   1981     super(ExamplesPerSecondHook, self).__init__(
   1982         every_n_steps=every_n_steps,
   1983         every_n_secs=every_n_secs,
   1984         output_dir=output_dir,
   1985         summary_writer=summary_writer)
   1986 
   1987   def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
   1988     global_step_per_sec = elapsed_steps / elapsed_time
   1989     examples_per_sec = self._batch_size * global_step_per_sec
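        # For example (illustrative numbers): with batch_size=1024 and
        # 2.5 global_step/sec, this reports examples/sec = 2560.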
   1990     if self._summary_writer is not None:
   1991       global_step_summary = Summary(value=[
   1992           Summary.Value(tag='global_step/sec', simple_value=global_step_per_sec)
   1993       ])
   1994       example_summary = Summary(value=[
   1995           Summary.Value(tag='examples/sec', simple_value=examples_per_sec)
   1996       ])
   1997       self._summary_writer.add_summary(global_step_summary, global_step)
   1998       self._summary_writer.add_summary(example_summary, global_step)
   1999     logging.info('global_step/sec: %g', global_step_per_sec)
   2000     logging.info('examples/sec: %g', examples_per_sec)
   2001 
   2002 
   2003 class InstallSignalHandlerHook(session_run_hook.SessionRunHook):
   2004   """Change SIGINT (CTRL^C) handler to force quit the process.
   2005 
   2006   The default behavior often results in hanging processes.
   2007   The original handler is restored after training/evaluation.
   2008   """
   2009 
   2010   def __init__(self):
   2011     self._signal_fn = signal.getsignal(signal.SIGINT)
   2012 
   2013   def before_run(self, run_context):
   2014     signal.signal(signal.SIGINT, signal.SIG_DFL)
   2015 
   2016   def end(self, session):
   2017     signal.signal(signal.SIGINT, self._signal_fn)
   2018 
   2019 
   2020 class TPUEstimator(estimator_lib.Estimator):
   2021   """Estimator with TPU support.
   2022 
   2023   TPUEstimator also supports training on CPU and GPU. You don't need to define
   2024   a separate `tf.estimator.Estimator`.
   2025 
   2026   TPUEstimator handles many of the details of running on TPU devices, such as
   2027   replicating inputs and models for each core, and returning to host
   2028   periodically to run hooks.
   2029 
   2030   TPUEstimator transforms a global batch size in params to a per-shard batch
   2031   size when calling the `input_fn` and `model_fn`. Users should specify the
   2032   global batch size in the constructor, and then get the batch size for each
   2033   shard via `params['batch_size']` in `input_fn` and `model_fn` (see below).
   2034 
   2035   - For training, `model_fn` gets per-core batch size; `input_fn` may get
   2036     per-core or per-host batch size depending on `per_host_input_for_training`
   2037     in `TPUConfig` (See docstring for TPUConfig for details).
   2038 
   2039   - For evaluation and prediction, `model_fn` gets per-core batch size and
   2040     `input_fn` gets per-host batch size.
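
          For example (an illustrative sketch; the actual number of shards depends
          on the TPU topology and `TPUConfig`): with `train_batch_size=1024` and 8
          shards, `model_fn` sees `params['batch_size'] == 128`.

          ```
          def model_fn(features, labels, mode, params):
            # 1024 (global) / 8 (shards) = 128 examples per shard per step.
            per_shard_batch_size = params['batch_size']
            ...
          ```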
   2041 
   2042   Evaluation
   2043   ==========
   2044 
   2045   `model_fn` should return `TPUEstimatorSpec`, which expects the `eval_metrics`
   2046   for TPU evaluation. If eval_on_tpu is False, the evaluation will execute on
   2047   CPU or GPU; in this case the following discussion on TPU evaluation does not
   2048   apply.
   2049 
   2050   `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`, where
   2051   `tensors` could be a list of any nested structure of `Tensor`s (See
   2052   `TPUEstimatorSpec` for details).  `metric_fn` takes the `tensors` and returns
   2053   a dict from metric string name to the result of calling a metric function,
   2054   namely a `(metric_tensor, update_op)` tuple.
   2055 
   2056   One can set `use_tpu` to `False` for testing. All training, evaluation, and
   2057   predict will be executed on CPU. `input_fn` and `model_fn` will receive
   2058   `train_batch_size` or `eval_batch_size` unmodified as `params['batch_size']`.
   2059 
   2060   Current limitations:
   2061   --------------------
   2062 
   2063   1. TPU evaluation only works on a single host (one TPU worker), except in
   2064      BROADCAST mode.
   2065 
   2066   2. `input_fn` for evaluation should **NOT** raise an end-of-input exception
   2067      (`OutOfRangeError` or `StopIteration`). All evaluation steps and all
   2068      batches should have the same size.
   2069 
   2070   Example (MNIST):
   2071   ----------------
   2072 
   2073   ```
   2074   # The metric Fn which runs on CPU.
   2075   def metric_fn(labels, logits):
   2076     predictions = tf.argmax(logits, 1)
   2077     return {
   2078       'accuracy': tf.metrics.accuracy(
   2079           labels=labels, predictions=predictions),
   2080     }
   2081 
   2082   # Your model Fn which runs on TPU (eval_metrics is a list in this example)
   2083   def model_fn(features, labels, mode, config, params):
   2084     ...
   2085     logits = ...
   2086 
   2087     if mode == tf.estimator.ModeKeys.EVAL:
   2088       return tpu_estimator.TPUEstimatorSpec(
   2089           mode=mode,
   2090           loss=loss,
   2091           eval_metrics=(metric_fn, [labels, logits]))
   2092 
   2093   # or specify the eval_metrics tensors as dict.
   2094   def model_fn(features, labels, mode, config, params):
   2095     ...
   2096     final_layer_output = ...
   2097 
   2098     if mode == tf.estimator.ModeKeys.EVAL:
   2099       return tpu_estimator.TPUEstimatorSpec(
   2100           mode=mode,
   2101           loss=loss,
   2102           eval_metrics=(metric_fn, {
   2103               'labels': labels,
   2104               'logits': final_layer_output,
   2105           }))
   2106   ```
   2107 
   2108   Prediction
   2109   ==========
   2110 
   2111   Prediction on TPU is an experimental feature to support large batch inference.
   2112   It is not designed for latency-critical systems. In addition, due to some
   2113   usability issues, for prediction with a small dataset, CPU `.predict`, i.e.,
   2114   creating a new `TPUEstimator` instance with `use_tpu=False`, might be more
   2115   convenient.
   2116 
   2117   Note: In contrast to TPU training/evaluation, the `input_fn` for prediction
   2118   *should* raise an end-of-input exception (`OutOfRangeError` or
   2119   `StopIteration`), which serves as the stopping signal to `TPUEstimator`. To be
   2120   precise, the ops created by `input_fn` produce one batch of the data.
   2121   The `predict()` API processes one batch at a time. When reaching the end of
   2122   the data source, an end-of-input exception should be raised by one of these
   2123   operations. The user usually does not need to do this manually. As long as the
   2124   dataset is not repeated forever, the `tf.data` API will raise an end-of-input
   2125   exception automatically after the last batch has been produced.
   2126 
   2127   Note: Estimator.predict returns a Python generator. Please consume all the
   2128   data from the generator so that TPUEstimator can shut down the TPU system
   2129   properly.
   2130 
   2131   Current limitations:
   2132   --------------------
   2133   1. TPU prediction only works on a single host (one TPU worker).
   2134 
   2135   2. `input_fn` must return a `Dataset` instance rather than `features`. In
   2136   fact, .train() and .evaluate() also support a Dataset as the return value.
   2137 
   2138   Example (MNIST):
   2139   ----------------
   2140   ```
   2141   height = 32
   2142   width = 32
   2143   total_examples = 100
   2144 
   2145   def predict_input_fn(params):
   2146     batch_size = params['batch_size']
   2147 
   2148     images = tf.random_uniform(
   2149         [total_examples, height, width, 3], minval=-1, maxval=1)
   2150 
   2151     dataset = tf.data.Dataset.from_tensor_slices(images)
   2152     dataset = dataset.map(lambda images: {'image': images})
   2153 
   2154     dataset = dataset.batch(batch_size)
   2155     return dataset
   2156 
   2157   def model_fn(features, labels, params, mode):
   2158      # Generate predictions, called 'output', from features['image']
   2159 
   2160     if mode == tf.estimator.ModeKeys.PREDICT:
   2161       return tf.contrib.tpu.TPUEstimatorSpec(
   2162           mode=mode,
   2163           predictions={
   2164               'predictions': output,
   2165               'is_padding': features['is_padding']
   2166           })
   2167 
   2168   tpu_est = TPUEstimator(
   2169       model_fn=model_fn,
   2170       ...,
   2171       predict_batch_size=16)
   2172 
   2173   # Fully consume the generator so that TPUEstimator can shutdown the TPU
   2174   # system.
   2175   for item in tpu_est.predict(input_fn=predict_input_fn):
   2176     # Filter out item if the `is_padding` is 1.
   2177     # Process the 'predictions'
   2178   ```
   2179 
   2180   Exporting
   2181   =========
   2182 
   2183   `export_savedmodel` exports 2 metagraphs, one with `tag_constants.SERVING`,
   2184   and another with `tag_constants.SERVING` and `tag_constants.TPU`.
   2185   At serving time, these tags are used to select the metagraph to load.
   2186 
   2187   Before running the graph on TPU, the TPU system needs to be initialized. If
   2188   the TensorFlow Serving model-server is used, this is done automatically. If
   2189   not, please call `session.run(tpu.initialize_system())`.
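
          For example (a minimal sketch; `export_dir` is assumed to point at the
          exported SavedModel):

          ```
          with tf.Session(graph=tf.Graph()) as sess:
            tf.saved_model.loader.load(
                sess, [tag_constants.SERVING, tag_constants.TPU], export_dir)
            sess.run(tpu.initialize_system())
            # ... then run the serving signature's fetches with sess.run(...).
          ```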
   2190 
   2191   `tpu.outside_compilation` can be used to wrap TPU incompatible ops in
   2192   `model_fn`.
   2193 
   2194   Example:
   2195   ----------------
   2196 
   2197   ```
   2198   def model_fn(features, labels, mode, config, params):
   2199     ...
   2200     logits = ...
   2201     export_outputs = {
   2202       'logits': export_output_lib.PredictOutput(
   2203         {'logits': logits})
   2204     }
   2205 
   2206     def host_call(logits):
   2207       class_ids = math_ops.argmax(logits)
   2208       classes = string_ops.as_string(class_ids)
   2209       export_outputs['classes'] = (
   2210           export_output_lib.ClassificationOutput(classes=classes))
   2211 
   2212     tpu.outside_compilation(host_call, logits)
   2213 
   2214     ...
   2215   ```
   2216 
   2217   """
   2218 
   2219   def __init__(self,
   2220                model_fn=None,
   2221                model_dir=None,
   2222                config=None,
   2223                params=None,
   2224                use_tpu=True,
   2225                train_batch_size=None,
   2226                eval_batch_size=None,
   2227                predict_batch_size=None,
   2228                batch_axis=None,
   2229                eval_on_tpu=True,
   2230                export_to_tpu=True,
   2231                export_to_cpu=True,
   2232                warm_start_from=None,
   2233                experimental_exported_model_uses_all_cores=False,
   2234                experimental_export_device_assignment=False,
   2235                experimental_embedding_config_spec=None):
   2236     """Constructs an `TPUEstimator` instance.
   2237 
   2238     Args:
   2239       model_fn: Model function as required by `Estimator` which returns
   2240         EstimatorSpec or TPUEstimatorSpec. `training_hooks`, `evaluation_hooks`,
   2241         and `prediction_hooks` must not capture any TPU Tensor inside the
   2242         model_fn.
   2243       model_dir: Directory to save model parameters, graph, etc. This can
   2244         also be used to load checkpoints from the directory into an estimator
   2245         to continue training a previously saved model. If `None`, the model_dir
   2246         in `config` will be used if set. If both are set, they must be the
   2247         same. If both are `None`, a temporary directory will be used.
   2248       config: A `tpu_config.RunConfig` configuration object. Cannot be `None`.
   2249       params: An optional `dict` of hyper parameters that will be passed into
   2250         `input_fn` and `model_fn`.  Keys are names of parameters, values are
   2251         basic python types. There are reserved keys for `TPUEstimator`,
   2252         including 'batch_size'.
   2253       use_tpu: A bool indicating whether TPU support is enabled. Currently:
   2254         - TPU training and evaluation respect this bit, but eval_on_tpu can
   2255           override execution of eval (see below).
            - Predict still happens on CPU.
   2256       train_batch_size: An int representing the global training batch size.
   2257         TPUEstimator transforms this global batch size to a per-shard batch
   2258         size, as params['batch_size'], when calling `input_fn` and `model_fn`.
   2259         Cannot be `None` if `use_tpu` is `True`. Must be divisible by total
   2260         number of replicas.
   2261       eval_batch_size: An int representing evaluation batch size. Must be
   2262         divisible by total number of replicas.
   2263       predict_batch_size: An int representing the prediction batch size. Must be
   2264         divisible by total number of replicas.
   2265       batch_axis: A python tuple of int values describing how each tensor
   2266         produced by the Estimator `input_fn` should be split across the TPU
   2267         compute shards. For example, if your input_fn produced (images, labels)
   2268         where the images tensor is in `HWCN` format, your shard dimensions would
   2269         be [3, 0], where 3 corresponds to the `N` dimension of your images
   2270         Tensor, and 0 corresponds to the dimension along which to split the
   2271         labels to match up with the corresponding images. If None is supplied,
   2272         and per_host_input_for_training is True, batches will be sharded based
   2273         on the major dimension. If tpu_config.per_host_input_for_training is
   2274         False or `PER_HOST_V2`, batch_axis is ignored.
   2275       eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the
   2276         model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`.
   2277       export_to_tpu: If True, `export_savedmodel()` exports a metagraph for
   2278         serving on TPU. Note that unsupported export modes such as EVAL will be
   2279         ignored. For those modes, only a CPU model will be exported.
   2280         Currently, export_to_tpu only supports PREDICT.
   2281       export_to_cpu: If True, `export_savedmodel()` exports a metagraph for
   2282         serving on CPU.
   2283       warm_start_from: Optional string filepath to a checkpoint or SavedModel to
   2284         warm-start from, or a `tf.estimator.WarmStartSettings` object to fully
   2285         configure warm-starting.  If the string filepath is provided instead of
   2286         a `WarmStartSettings`, then all variables are warm-started, and it is
   2287         assumed that vocabularies and Tensor names are unchanged.
   2288       experimental_exported_model_uses_all_cores: Whether to round-robin among
   2289         all cores visible to the host which is serving the saved model, or to
   2290         use a single core. This is a temporary flag to enable using all TPU
   2291         cores for inference with TPUPartitionedCall(). Once outside compilation
   2292         is supported in TPUPartitionedCall(), this flag will be enabled by
   2293         default.
   2294       experimental_export_device_assignment: Whether to include the device
   2295         assignment in the exported model. Doing so is useful in case of model
   2296         parallel inference but will tie the exported model to the TPU topology
   2297         used to export the model.
   2298       experimental_embedding_config_spec: Optional EmbeddingConfigSpec instance
   2299         to support using TPU embedding. IT IS STILL WORK IN PROGRESS, SO PLEASE
   2300         DO NOT USE.
   2301 
   2302     Raises:
   2303       ValueError: `params` has reserved keys already.
   2304     """
   2305     if config is None or not isinstance(config, tpu_config.RunConfig):
   2306       raise ValueError(
   2307           '`config` must be provided with type `tpu_config.RunConfig`')
   2308 
   2309     if params is not None and any(k in params for k in _RESERVED_PARAMS_KEYS):
   2310       raise ValueError('{} are reserved keys but existed in params {}.'.format(
   2311           _RESERVED_PARAMS_KEYS, params))
   2312 
   2313     if use_tpu:
   2314       # Perform some very basic validations. More validations will be found in
   2315       # _InternalTPUContext.
   2316       if train_batch_size is None:
   2317         raise ValueError('`train_batch_size` cannot be `None`')
   2318       util_lib.check_positive_integer(train_batch_size, 'train_batch_size')
   2319 
   2320       if (config.tpu_config.per_host_input_for_training is
   2321           tpu_config.InputPipelineConfig.PER_SHARD_V1 and
   2322           config.tpu_config.num_cores_per_replica):
   2323         raise ValueError(
   2324             'Model parallelism only supports per host input for training. '
   2325             'Please adjust TPUConfig.per_host_input_for_training.')
   2326 
   2327       if eval_batch_size is not None:
   2328         util_lib.check_positive_integer(eval_batch_size, 'eval_batch_size')
   2329 
   2330       if predict_batch_size is not None:
   2331         util_lib.check_positive_integer(predict_batch_size,
   2332                                         'predict_batch_size')
   2333 
   2334     # Verifies the model_fn signature according to Estimator framework.
   2335     estimator_lib._verify_model_fn_args(model_fn, params)  # pylint: disable=protected-access
   2336     # We cannot store config and params in this constructor as parent
   2337     # constructor might change them, such as assigning a temp dir for
   2338     # config.model_dir.
   2339     model_function = self._augment_model_fn(model_fn, batch_axis)
   2340 
   2341     # Overwrite log_step_count_steps to prevent TensorLoggingHook and
   2342     # StepCounterHook from being created in Estimator. TPUEstimator already
   2343     # added equivalent hooks in _augment_model_fn above.
   2344     self._log_every_n_steps = config.log_step_count_steps
   2345     config = config.replace(log_step_count_steps=None)
   2346 
   2347     # Pass non-None params since the wrapped model_fn expects them.
   2348     params = params or {}
   2349     super(TPUEstimator, self).__init__(
   2350         model_fn=model_function,
   2351         model_dir=model_dir,
   2352         config=config,
   2353         params=params,
   2354         warm_start_from=warm_start_from)
   2355     self._iterations_per_training_loop = (
   2356         self._config.tpu_config.iterations_per_loop)
   2357 
   2358     # All properties passed to _InternalTPUContext are immutable.
   2359     # pylint: disable=protected-access
   2360     self._ctx = tpu_context._get_tpu_context(
   2361         self._config, train_batch_size, eval_batch_size, predict_batch_size,
   2362         use_tpu, eval_on_tpu, experimental_embedding_config_spec)
   2363 
   2364     self._export_to_cpu = export_to_cpu
   2365     self._export_to_tpu = export_to_tpu
   2366     self._experimental_exported_model_uses_all_cores = (
   2367         experimental_exported_model_uses_all_cores)
   2368     self._experimental_export_device_assignment = (
   2369         experimental_export_device_assignment)
   2370     if (experimental_exported_model_uses_all_cores and
   2371         experimental_export_device_assignment):
   2372       raise ValueError('experimental_exported_model_uses_all_cores and '
   2373                        'experimental_export_device_assignment are not supported '
   2374                        'at the same time.')
   2375 
   2376     self._is_input_fn_invoked = None
   2377     self._rendezvous = {}
   2378 
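          # Illustrative usage sketch (comment only, not part of the original
          # code). It shows how the arguments validated above are typically
          # supplied; `my_model_fn`, `tpu_grpc_url`, and the numeric values are
          # placeholders.
          #
          #   run_config = tpu_config.RunConfig(
          #       master=tpu_grpc_url,
          #       model_dir='/tmp/my_model',
          #       tpu_config=tpu_config.TPUConfig(iterations_per_loop=100))
          #   estimator = TPUEstimator(
          #       model_fn=my_model_fn,
          #       config=run_config,
          #       use_tpu=True,
          #       train_batch_size=1024,
          #       eval_batch_size=1024)
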
   2379   def _add_meta_graph_for_mode(self,
   2380                                builder,
   2381                                input_receiver_fn_map,
   2382                                checkpoint_path,
   2383                                save_variables=True,
   2384                                mode=model_fn_lib.ModeKeys.PREDICT,
   2385                                export_tags=None,
   2386                                check_variables=True):
   2387     if self._export_to_tpu and mode != model_fn_lib.ModeKeys.PREDICT:
   2388       logging.warning('TPUEstimator only handles mode PREDICT for exporting '
   2389                       'when `export_to_tpu` is `True`; Mode {} will be ignored '
   2390                       'for TPU.'.format(mode))
   2391 
   2392     if not self._export_to_cpu and not self._export_to_tpu:
   2393       raise ValueError('One of export_to_cpu and export_to_tpu must be true.')
   2394 
   2395     if self._export_to_cpu:
   2396       (super(TPUEstimator, self)._add_meta_graph_for_mode(
   2397           builder,
   2398           input_receiver_fn_map,
   2399           checkpoint_path,
   2400           save_variables,
   2401           mode=mode,
   2402           export_tags=export_tags,
   2403           check_variables=check_variables))
   2404 
   2405     if self._export_to_tpu and mode == model_fn_lib.ModeKeys.PREDICT:
   2406       input_receiver_fn_map = {
   2407           _REWRITE_FOR_INFERENCE_MODE: input_receiver_fn_map[mode]
   2408       }
   2409       export_tags = [tag_constants.SERVING, tag_constants.TPU]
   2410       mode = _REWRITE_FOR_INFERENCE_MODE
   2411 
   2412       # See b/110052256 for why `check_variables` is `False`.
   2413       if not self._export_to_cpu:
   2414         check_variables = save_variables = True
   2415       else:
   2416         check_variables = save_variables = False
   2417       (super(TPUEstimator, self)._add_meta_graph_for_mode(
   2418           builder,
   2419           input_receiver_fn_map,
   2420           checkpoint_path,
   2421           save_variables=save_variables,
   2422           mode=mode,
   2423           export_tags=export_tags,
   2424           check_variables=check_variables))
   2425 
   2426   def _call_model_fn(self, features, labels, mode, config):
   2427     if mode == _REWRITE_FOR_INFERENCE_MODE:
   2428       return self._call_model_fn_for_inference(features, labels, mode, config)
   2429     else:
   2430       return super(TPUEstimator, self)._call_model_fn(features, labels, mode,
   2431                                                       config)
   2432 
   2433   def _call_model_fn_for_inference(self, features, labels, mode, config):
   2434     """Wraps `_call_model_fn` for `export_savedmodel`."""
   2435     if mode != _REWRITE_FOR_INFERENCE_MODE:
   2436       raise ValueError('mode must be {}; '
   2437                        'got {}.'.format(_REWRITE_FOR_INFERENCE_MODE, mode))
   2438 
   2439     computation, capture = self._build_computation_for_inference(
   2440         features, labels, mode, config)
   2441     tensors = call_computation(
   2442         computation,
   2443         experimental_exported_model_uses_all_cores=self
   2444         ._experimental_exported_model_uses_all_cores)
   2445     estimator_spec, export_outputs_dict, predictions_dict, none_indices = (
   2446         capture.get())
   2447     predictions_list = tensors[:len(predictions_dict)]
   2448     export_outputs_list_without_none = tensors[len(predictions_dict):]
   2449 
   2450     # Reinsert `None`s which we've taken out in
   2451     # `_build_computation_for_inference()`.
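            # Worked example (added for clarity; names are illustrative): if the
            # flattened export outputs were [t0, None, t1], then
            # `export_outputs_list_without_none` holds the rewritten [t0', t1']
            # and `none_indices` is [1], so the loop below rebuilds
            # [t0', None, t1'].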
   2452     export_outputs_list = []
   2453     while none_indices or export_outputs_list_without_none:
   2454       if none_indices and none_indices[0] == len(export_outputs_list):
   2455         export_outputs_list.append(None)
   2456         none_indices.pop(0)
   2457       else:
   2458         export_outputs_list.append(export_outputs_list_without_none.pop(0))
   2459 
   2460     # Reconstruct `export_outputs` with updated tensors.
   2461     new_export_outputs_dict = nest.pack_sequence_as(export_outputs_dict,
   2462                                                     export_outputs_list)
   2463     export_outputs = estimator_spec.export_outputs
   2464     new_export_outputs = collections.OrderedDict(
   2465         (k, _clone_export_output_with_tensors(export_outputs[k], v))
   2466         for k, v in six.iteritems(new_export_outputs_dict))
   2467     # Reconstruct `predictions` with updated tensors.
   2468     new_predictions = nest.pack_sequence_as(predictions_dict, predictions_list)
   2469     if (len(new_predictions) == 1 and
   2470         _KEY_WHEN_PREDICTIONS_IS_A_TENSOR in new_predictions):
   2471       new_predictions = new_predictions[_KEY_WHEN_PREDICTIONS_IS_A_TENSOR]
   2472 
   2473     return estimator_spec._replace(
   2474         export_outputs=new_export_outputs, predictions=new_predictions)
   2475 
   2476   def _build_computation_for_inference(self, features, labels, mode, config):
   2477     capture = _CapturedObject()
   2478 
   2479     def computation():
   2480       """Computation to be passed to `TPUPartitionedCall()`."""
   2481       tpu_computation, tpu_capture = self._build_tpu_computation_for_inference(
   2482           features, labels, mode, config)
   2483 
   2484       if self._experimental_export_device_assignment:
   2485         # Export the device assignment as part of the model. This is useful for
   2486         # model parallel use cases where the model relies on the mapping between
   2487         # logical and physical devices.
   2488         with self._ctx.with_mode(mode) as ctx:
   2489           device_assignment = ctx.device_assignment
   2490       else:
   2491         device_assignment = None
   2492 
   2493       if self._experimental_exported_model_uses_all_cores:
   2494         tensors_on_cpu = tpu.rewrite(
   2495             tpu_computation, device_assignment=device_assignment)
   2496         tpu.prune_unconnected_ops_from_xla(ops.get_default_graph())
   2497       else:
   2498         tensors_on_cpu = tpu.rewrite_for_inference(
   2499             tpu_computation, device_assignment=device_assignment)
   2500 
   2501       (estimator_spec, export_outputs_dict, export_outputs_list,
   2502        predictions_dict) = (
   2503            tpu_capture.get())
   2504       predictions_list = tensors_on_cpu[:len(predictions_dict)]
   2505       export_outputs_tpu_on_cpu_list = tensors_on_cpu[len(predictions_dict):]
   2506 
   2507       # Reconstruct tensors used in export_outputs, with TPU tensors replaced
   2508       # with their CPU counterpart returned from `rewrite_for_inference()`.
   2509       # `function.Defun()` does not like `None`s in return values, so we leave
   2510       # `None`s out but record their positions for later reconstruction.
   2511       export_outputs_list_without_none = []
   2512       none_indices = []
   2513       for i, t in enumerate(export_outputs_list):
   2514         if t is None:
   2515           none_indices.append(i)
   2516         else:
   2517           export_outputs_list_without_none.append(
   2518               export_outputs_tpu_on_cpu_list.pop(0))
   2519 
   2520       capture.capture((estimator_spec, export_outputs_dict, predictions_dict,
   2521                        none_indices))
   2522       return predictions_list + export_outputs_list_without_none
   2523 
   2524     return computation, capture
   2525 
   2526   def _build_tpu_computation_for_inference(self, features, labels, mode,
   2527                                            config):
   2528     capture = _CapturedObject()
   2529 
   2530     def computation():
   2531       """Compute tpu tensors used in export_outputs.
   2532 
   2533       Passed to rewrite_for_inference so that model_fn will be called under
   2534       the rewriting contexts. Only tpu tensors are returned, but export_outputs
   2535       and scaffold are captured.
   2536 
   2537       Returns:
   2538          A list of Tensors used in export_outputs and not marked for
   2539          outside_compilation.
   2540       """
   2541       # We should only call model fn once and it should be inside `computation`
   2542       # so that building the graph will happen under `rewrite_for_inference`.
   2543       estimator_spec = super(TPUEstimator, self)._call_model_fn(
   2544           features, labels, mode, config)
   2545 
   2546       # We pick the TPU tensors out from `export_output` and later return them
   2547       # from `computation` for rewriting.
   2548       export_outputs_dict = collections.OrderedDict(
   2549           (k, _export_output_to_tensors(v))
   2550           for k, v in six.iteritems(estimator_spec.export_outputs))
   2551       export_outputs_list = nest.flatten(export_outputs_dict)
   2552       export_outputs_tpu_list = [
   2553           t for t in export_outputs_list if t is not None
   2554       ]
   2555 
   2556       if isinstance(estimator_spec.predictions, dict):
   2557         predictions_dict = collections.OrderedDict(
   2558             (k, v) for k, v in six.iteritems(estimator_spec.predictions))
   2559       else:
   2560         predictions_dict = {
   2561             _KEY_WHEN_PREDICTIONS_IS_A_TENSOR: estimator_spec.predictions
   2562         }
   2563       predictions_list = nest.flatten(predictions_dict)
   2564 
   2565       # We cannot return everything we want through the return values, so
   2566       # capture the rest here for later use.
   2567       capture.capture((estimator_spec, export_outputs_dict, export_outputs_list,
   2568                        predictions_dict))
   2569       return predictions_list + export_outputs_tpu_list
   2570 
   2571     return computation, capture
   2572 
   2573   def _create_global_step(self, graph):
   2574     """Creates a global step suitable for TPUs.
   2575 
   2576     Args:
   2577       graph: The graph in which to create the global step.
   2578 
   2579     Returns:
   2580       A global step `Tensor`.
   2581 
   2582     Raises:
   2583       ValueError: if the global step tensor is already defined.
   2584     """
   2585     return _create_global_step(graph)
   2586 
   2587   def _convert_train_steps_to_hooks(self, steps, max_steps):
   2588     with self._ctx.with_mode(model_fn_lib.ModeKeys.TRAIN) as ctx:
   2589       if ctx.is_running_on_cpu():
   2590         return super(TPUEstimator, self)._convert_train_steps_to_hooks(
   2591             steps, max_steps)
   2592 
   2593     # On TPU.
   2594     if steps is None and max_steps is None:
   2595       raise ValueError(
   2596           'For TPU training, one of `steps` or `max_steps` must be set. '
   2597           'Cannot be both `None`.')
   2598 
   2599     # Estimator.train has explicit positiveness check.
   2600     if steps is not None:
   2601       util_lib.check_positive_integer(steps, 'Train steps')
   2602     if max_steps is not None:
   2603       util_lib.check_positive_integer(max_steps, 'Train max_steps')
   2604 
   2605     return [
   2606         _TPUStopAtStepHook(self._iterations_per_training_loop, steps, max_steps)
   2607     ]
   2608 
   2609   def _convert_eval_steps_to_hooks(self, steps):
   2610     with self._ctx.with_mode(model_fn_lib.ModeKeys.EVAL) as ctx:
   2611       if ctx.is_running_on_cpu():
   2612         return super(TPUEstimator, self)._convert_eval_steps_to_hooks(steps)
   2613 
   2614     if steps is None:
   2615       raise ValueError('Evaluate `steps` must be set on TPU. Cannot be `None`.')
   2616 
   2617     util_lib.check_positive_integer(steps, 'Eval steps')
   2618 
   2619     return [
   2620         evaluation._StopAfterNEvalsHook(  # pylint: disable=protected-access
   2621             num_evals=steps),
   2622         _SetEvalIterationsHook(steps)
   2623     ]
   2624 
   2625   def _call_input_fn(self, input_fn, mode):
   2626     """Calls the input function.
   2627 
   2628     Args:
   2629       input_fn: The input function.
   2630       mode: ModeKeys
   2631 
   2632     Returns:
   2633       In TPU mode, returns an input_fn to be called later in model_fn.
   2634       Otherwise, calls the input_fn and returns either features or
   2635         (features, labels).
   2636 
   2637     Raises:
   2638       ValueError: if input_fn takes invalid arguments or does not have `params`.
   2639     """
   2640     input_fn_args = function_utils.fn_args(input_fn)
   2641     config = self.config  # a deep copy.
   2642     kwargs = {}
   2643     if 'params' in input_fn_args:
   2644       kwargs['params'] = self.params  # a deep copy.
   2645     else:
   2646       raise ValueError('input_fn ({}) does not include params argument, '
   2647                        'required by TPUEstimator to pass batch size as '
   2648                        'params["batch_size"]'.format(input_fn))
   2649     if 'config' in input_fn_args:
   2650       kwargs['config'] = config
   2651 
   2652     if 'mode' in input_fn_args:
   2653       kwargs['mode'] = mode
   2654 
   2655     # Records the fact input_fn has been invoked.
   2656     self._is_input_fn_invoked = True
   2657 
   2658     with self._ctx.with_mode(mode) as ctx:
   2659       # Set the batch size in params first. This helps the user have the same
   2660       # input_fn for use_tpu=True/False.
   2661       batch_size_for_input_fn = ctx.batch_size_for_input_fn
   2662       if batch_size_for_input_fn is not None:
   2663         _add_item_to_params(kwargs['params'], _BATCH_SIZE_KEY,
   2664                             batch_size_for_input_fn)
   2665 
   2666       # For export_savedmodel, input_fn is never passed to Estimator. So,
   2667       # `is_export_mode` must be False.
   2668       if ctx.is_running_on_cpu(is_export_mode=False):
   2669         with ops.device('/device:CPU:0'):
   2670           return input_fn(**kwargs)
   2671 
   2672       # For TPU computation, input_fn should be invoked in a tf.while_loop for
   2673       # performance. While constructing the tf.while_loop, the structure of
   2674       # inputs returned by the `input_fn` needs to be recorded. The structure
   2675       # includes whether features or labels is dict or single Tensor, dict keys,
   2676       # tensor shapes, and dtypes. The recorded structure is used to create the
   2677       # infeed dequeue ops, which must be wrapped and passed as a Fn, called
   2678       # inside the TPU computation, as the TPU computation is wrapped inside a
   2679       # tf.while_loop also. So, we either pass input_fn to model_fn or pass
   2680       # dequeue_fn to model_fn. Here, `input_fn` is passed directly as
   2681       # `features` in `model_fn` signature.
   2682       def _input_fn(ctx):
   2683         _add_item_to_params(kwargs['params'], _CTX_KEY, ctx)
   2684         return input_fn(**kwargs)
   2685 
   2686       return _input_fn
   2687 
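          # Illustrative sketch (comment only, not part of the original code) of
          # an `input_fn` that satisfies the contract enforced above: it must
          # accept `params` and should read the batch size chosen by
          # TPUEstimator from params['batch_size']. `my_files` and `parse_fn`
          # are placeholder names.
          #
          #   def my_input_fn(params):
          #     batch_size = params['batch_size']
          #     dataset = dataset_ops.Dataset.from_tensor_slices(my_files)
          #     dataset = dataset.map(parse_fn)
          #     return dataset.batch(batch_size, drop_remainder=True)
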
   2688   def _validate_features_in_predict_input(self, result):
   2689     """Skip the validation.
   2690 
   2691     For TPUEstimator, we do not need to check the result type. `_InputPipeline`
   2692     has a stronger check. The parent class's check emits a confusing warning.
   2693 
   2694     Args:
   2695       result: `features` returned by input_fn.
   2696     """
   2697     pass
   2698 
   2699   def train(self,
   2700             input_fn,
   2701             hooks=None,
   2702             steps=None,
   2703             max_steps=None,
   2704             saving_listeners=None):
   2705     rendezvous = error_handling.ErrorRendezvous(num_sources=3)
   2706     self._rendezvous[model_fn_lib.ModeKeys.TRAIN] = rendezvous
   2707     try:
   2708       return super(TPUEstimator, self).train(
   2709           input_fn=input_fn,
   2710           hooks=hooks,
   2711           steps=steps,
   2712           max_steps=max_steps,
   2713           saving_listeners=saving_listeners)
   2714     except Exception:  # pylint: disable=broad-except
   2715       rendezvous.record_error('training_loop', sys.exc_info())
   2716     finally:
   2717       rendezvous.record_done('training_loop')
   2718       rendezvous.raise_errors()
   2719 
   2720   def evaluate(self,
   2721                input_fn,
   2722                steps=None,
   2723                hooks=None,
   2724                checkpoint_path=None,
   2725                name=None):
   2726     rendezvous = error_handling.ErrorRendezvous(num_sources=3)
   2727     self._rendezvous[model_fn_lib.ModeKeys.EVAL] = rendezvous
   2728     try:
   2729       return super(TPUEstimator, self).evaluate(
   2730           input_fn,
   2731           steps=steps,
   2732           hooks=hooks,
   2733           checkpoint_path=checkpoint_path,
   2734           name=name)
   2735     except Exception:  # pylint: disable=broad-except
   2736       rendezvous.record_error('evaluation_loop', sys.exc_info())
   2737     finally:
   2738       rendezvous.record_done('evaluation_loop')
   2739       rendezvous.raise_errors()
   2740 
   2741   def predict(self,
   2742               input_fn,
   2743               predict_keys=None,
   2744               hooks=None,
   2745               checkpoint_path=None,
   2746               yield_single_examples=True):
   2747     rendezvous = error_handling.ErrorRendezvous(num_sources=3)
   2748     self._rendezvous[model_fn_lib.ModeKeys.PREDICT] = rendezvous
   2749     try:
   2750       for result in super(TPUEstimator, self).predict(
   2751           input_fn=input_fn,
   2752           predict_keys=predict_keys,
   2753           hooks=hooks,
   2754           checkpoint_path=checkpoint_path,
   2755           yield_single_examples=yield_single_examples):
   2756         yield result
   2757     except Exception:  # pylint: disable=broad-except
   2758       rendezvous.record_error('prediction_loop', sys.exc_info())
   2759     finally:
   2760       rendezvous.record_done('prediction_loop')
   2761       rendezvous.raise_errors()
   2762 
   2763     rendezvous.record_done('prediction_loop')
   2764     rendezvous.raise_errors()
   2765 
   2766   def _augment_model_fn(self, model_fn, batch_axis):
   2767     """Returns a new model_fn, which wraps the TPU support."""
   2768 
   2769     def _model_fn(features, labels, mode, config, params):
   2770       """A Estimator `model_fn` for TPUEstimator."""
   2771 
   2772       # `input_fn` is called in `train()`, `evaluate()`, and `predict()`,
   2773       # but not in `export_savedmodel()`.
   2774       if self._is_input_fn_invoked:
   2775         is_export_mode = False
   2776       else:
   2777         is_export_mode = True
   2778 
   2779       # Clear the bit.
   2780       self._is_input_fn_invoked = None
   2781 
   2782       if is_export_mode:
   2783         if mode == _REWRITE_FOR_INFERENCE_MODE:
   2784           _add_item_to_params(params, _USE_TPU_KEY, True)
   2785           mode = model_fn_lib.ModeKeys.PREDICT
   2786         else:
   2787           _add_item_to_params(params, _USE_TPU_KEY, False)
   2788 
   2789       with self._ctx.with_mode(mode) as ctx:
   2790         model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx)
   2791 
   2792         # examples_hook is added to training_hooks for both CPU and TPU
   2793         # execution.
   2794         if self._log_every_n_steps is not None:
   2795           examples_hook = ExamplesPerSecondHook(
   2796               ctx.global_batch_size,
   2797               # pylint:disable=g-long-ternary
   2798               output_dir=(self.model_dir
   2799                           if not config or config.save_summary_steps
   2800                           else None),
   2801               # pylint:enable=g-long-ternary
   2802               every_n_steps=self._log_every_n_steps)
   2803 
   2804         if ctx.is_running_on_cpu(is_export_mode=is_export_mode):
   2805           logging.info('Running %s on CPU', mode)
   2806           estimator_spec = model_fn_wrapper.call_without_tpu(
   2807               features, labels, is_export_mode=is_export_mode)
   2808           if self._log_every_n_steps is not None:
   2809             estimator_spec = estimator_spec._replace(
   2810                 training_hooks=estimator_spec.training_hooks + (examples_hook,))
   2811           return estimator_spec
   2812 
   2813         assert labels is None, '`labels` passed to `model_fn` must be `None`.'
   2814         # TPUEstimator._call_input_fn passes `input_fn` as features to here.
   2815         assert callable(features), '`input_fn` is not callable.'
   2816         input_fn = features
   2817 
   2818         tpu_init_ops = []
   2819         if ctx.embedding_config and mode == model_fn_lib.ModeKeys.TRAIN:
   2820           dummy_table_variables, dummy_table_variables_init = (
   2821               tpu_embedding_gradient.create_dummy_table_variables(
   2822                   ctx.embedding_config.tpu_embedding))
   2823           ctx.embedding_config.dummy_table_variables = dummy_table_variables
   2824           tpu_init_ops.append(dummy_table_variables_init)
   2825 
   2826         input_holders = _InputPipeline(input_fn, batch_axis, ctx)
   2827         enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = (
   2828             input_holders.generate_infeed_enqueue_ops_and_dequeue_fn())
   2829 
   2830         graph = ops.get_default_graph()
   2831         for enqueue_op in enqueue_ops:
   2832           if isinstance(enqueue_op, list):
   2833             graph.get_collection_ref(_TPU_ENQUEUE_OPS).extend(enqueue_op)
   2834           else:
   2835             graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op)
   2836 
   2837         if mode == model_fn_lib.ModeKeys.TRAIN:
   2838           compile_op, loss, host_call, scaffold, training_hooks = (
   2839               _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn))
   2840           if ctx.embedding_config:
   2841             g = ops.get_default_graph()
   2842             table_to_config_dict = (
   2843                 ctx.embedding_config.tpu_embedding.table_to_config_dict)
   2844             optimization_parameters = (
   2845                 ctx.embedding_config.tpu_embedding.optimization_parameters)
   2846             embedding_variable_name_by_table, slot_variable_names_by_table = (
   2847                 _tpu_estimator_embedding.get_full_variable_names(
   2848                     g, table_to_config_dict, optimization_parameters
   2849                 )
   2850             )
   2851             embedding_variables_and_ops = (
   2852                 ctx.embedding_config.tpu_embedding.create_variables_and_ops(
   2853                     embedding_variable_name_by_table,
   2854                     slot_variable_names_by_table
   2855                 ))
   2856             tpu_init_ops.extend(embedding_variables_and_ops.load_ops())
   2857 
   2858           host_ops = host_call.create_tpu_hostcall()
   2859           if host_ops is None:
   2860             host_ops = []
   2861 
   2862           shutdown_hooks = []
   2863           shutdown_mode = os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN_MODE',
   2864                                          'shutdown_worker')
   2865           if shutdown_mode:
   2866             if shutdown_mode == 'shutdown_worker':
   2867               finalizer_hooks = [
   2868                   session_support.ShutdownLameWorkers(timeout_ms=60 * 1000),
   2869               ]
   2870             elif shutdown_mode == 'shutdown_computation':
   2871               finalizer_hooks = [
   2872                   session_support.RestartComputation(timeout_ms=60 * 1000),
   2873               ]
   2874             else:
   2875               raise ValueError(
   2876                   'Unknown TF_TPU_GRACEFUL_SHUTDOWN_MODE "%s"' % shutdown_mode)
   2877 
   2878             shutdown_hooks.append(
   2879                 session_support.GracefulShutdownHook(
   2880                     checkpoint_prefix=self.model_dir + '/model.ckpt',
   2881                     on_shutdown_hooks=finalizer_hooks))
   2882 
   2883           with ops.control_dependencies([loss]):
   2884             global_step = array_ops.identity(training.get_global_step())
   2885           hooks = input_hooks + shutdown_hooks
   2886           hooks.extend([
   2887               TPUInfeedOutfeedSessionHook(
   2888                   ctx,
   2889                   enqueue_ops,
   2890                   host_ops,
   2891                   tpu_compile_op=compile_op,
   2892                   run_infeed_loop_on_coordinator=(
   2893                       run_infeed_loop_on_coordinator),
   2894                   rendezvous=self._rendezvous[mode],
   2895                   master=self._config.master,
   2896                   session_config=self._session_config,
   2897                   tpu_init_ops=tpu_init_ops),
   2898               InstallSignalHandlerHook()
   2899           ])
   2900           if self._log_every_n_steps is not None:
   2901             logging_hook_frequency = (  # Divide and round up
   2902                 (self._log_every_n_steps +
   2903                  self._config.tpu_config.iterations_per_loop - 1) //
   2904                 self._config.tpu_config.iterations_per_loop)
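                    # Worked example (illustrative values): with
                    # log_step_count_steps=100 and iterations_per_loop=30, this
                    # yields (100 + 29) // 30 = 4, i.e. logging roughly every 4
                    # outer loops (~120 global steps).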
   2905             hooks.append(
   2906                 training.LoggingTensorHook({
   2907                     'loss': array_ops.identity(loss),
   2908                     'step': global_step,
   2909                 },
   2910                                            every_n_iter=logging_hook_frequency))
   2911             examples_hook._set_steps_per_run(  # pylint: disable=protected-access
   2912                 self._config.tpu_config.iterations_per_loop)
   2913             hooks.append(examples_hook)
   2914 
   2915           if training_hooks:
   2916             hooks.extend(training_hooks)
   2917 
   2918           chief_hooks = []
   2919           if (self._config.save_checkpoints_secs or
   2920               self._config.save_checkpoints_steps):
   2921             checkpoint_hook = training.CheckpointSaverHook(
   2922                 self.model_dir,
   2923                 save_secs=self._config.save_checkpoints_secs,
   2924                 save_steps=self._config.save_checkpoints_steps,
   2925                 scaffold=scaffold)
   2926             checkpoint_hook._set_steps_per_run(  # pylint: disable=protected-access
   2927                 self._config.tpu_config.iterations_per_loop)
   2928             chief_hooks.append(checkpoint_hook)
   2929 
   2930           summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss)
   2931           with ops.control_dependencies([loss]):
   2932             update_ops = _sync_variables_ops(ctx)
   2933             if ctx.embedding_config:
   2934               update_ops.extend(embedding_variables_and_ops.retrieve_ops())
   2935 
   2936           # Validate the TPU training graph to catch basic errors
   2937           _validate_tpu_training_graph()
   2938 
   2939           train_op = control_flow_ops.group(*update_ops)
   2940           graph.add_to_collection(_TPU_TRAIN_OP, train_op)
   2941 
   2942           return model_fn_lib.EstimatorSpec(
   2943               mode,
   2944               loss=loss,
   2945               training_chief_hooks=chief_hooks,
   2946               training_hooks=hooks,
   2947               train_op=train_op,
   2948               scaffold=scaffold)
   2949 
   2950         if mode == model_fn_lib.ModeKeys.EVAL:
   2951           compile_op, total_loss, host_calls, scaffold, eval_hooks = (
   2952               _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn))
   2953           if ctx.embedding_config:
   2954             g = ops.get_default_graph()
   2955             table_to_config_dict = (
   2956                 ctx.embedding_config.tpu_embedding.table_to_config_dict)
   2957             embedding_variable_name_by_table, _ = (
   2958                 _tpu_estimator_embedding.get_full_variable_names(
   2959                     g, table_to_config_dict)
   2960             )
   2961             embedding_variables_and_ops = (
   2962                 ctx.embedding_config.tpu_embedding.create_variables_and_ops(
   2963                     embedding_variable_name_by_table
   2964                 ))
   2965             tpu_init_ops.extend(embedding_variables_and_ops.load_ops())
   2966           iterations_per_loop_var = _create_or_get_iterations_per_loop()
   2967           mean_loss = math_ops.div(
   2968               total_loss,
   2969               math_ops.cast(iterations_per_loop_var, dtype=total_loss.dtype))
   2970 
   2971           with ops.control_dependencies([mean_loss]):
   2972             # After the TPU evaluation computation is done (the mean_loss
   2973             # tensor), read all variables back from the TPU and update the eval
   2974             # step counter properly.
   2975             internal_ops_to_run = _sync_variables_ops(ctx)
   2976             internal_ops_to_run.append(
   2977                 _increase_eval_step_op(iterations_per_loop_var))
   2978 
   2979           host_call_ret = host_calls.create_tpu_hostcall()
   2980           eval_metric_ops = {}
   2981           eval_update_ops = []
   2982 
   2983           eval_metrics = host_call_ret.get('eval_metrics', {})
   2984           if eval_metrics:
   2985             # Create a dummy metric update_op for all metrics. Estimator
   2986             # expects every metric in `eval_metric_ops` to have an update_op and
   2987             # calls them one by one. The real metric update_ops are invoked in a
   2988             # separate thread. So, here we give Estimator a dummy op for all
   2989             # metrics.
   2990             with ops.control_dependencies(internal_ops_to_run):
   2991               dummy_update_op = control_flow_ops.no_op()
   2992 
   2993             for k, v in eval_metrics.items():
   2994               eval_metric_ops[k] = (v[0], dummy_update_op)
   2995               eval_update_ops.append(v[1])
   2996           else:
   2997             # If no eval metrics are passed, create an identity node for the
   2998             # loss and add `internal_ops_to_run` to its dependencies. So
   2999             # `internal_ops_to_run` can be executed.
   3000             with ops.control_dependencies(internal_ops_to_run):
   3001               mean_loss = array_ops.identity(mean_loss)
   3002 
   3003           if 'host_call' not in host_call_ret:
   3004             host_ops = []
   3005           else:
   3006             host_ops = host_call_ret['host_call']
   3007           hooks = [
   3008               TPUInfeedOutfeedSessionHook(
   3009                   ctx,
   3010                   enqueue_ops,
   3011                   eval_update_ops + host_ops,
   3012                   tpu_compile_op=compile_op,
   3013                   run_infeed_loop_on_coordinator=(
   3014                       run_infeed_loop_on_coordinator),
   3015                   rendezvous=self._rendezvous[mode],
   3016                   master=self._config.evaluation_master,
   3017                   session_config=self._session_config,
   3018                   tpu_init_ops=tpu_init_ops)
   3019           ] + input_hooks
   3020 
   3021           if eval_hooks:
   3022             hooks.extend(eval_hooks)
   3023 
   3024           return model_fn_lib.EstimatorSpec(
   3025               mode,
   3026               loss=mean_loss,
   3027               evaluation_hooks=hooks,
   3028               eval_metric_ops=eval_metric_ops,
   3029               scaffold=scaffold)
   3030 
   3031         # Predict
   3032         assert mode == model_fn_lib.ModeKeys.PREDICT
   3033 
   3034         (compile_op, dummy_predict_op, host_calls,
   3035          scaffold, prediction_hooks) = _predict_on_tpu_system(
   3036              ctx, model_fn_wrapper, dequeue_fn)
   3037         with ops.control_dependencies([dummy_predict_op]):
   3038           internal_ops_to_run = _sync_variables_ops(ctx)
   3039           with ops.control_dependencies(internal_ops_to_run):
   3040             dummy_predict_op = control_flow_ops.no_op()
   3041 
   3042         # In train and evaluation, the main TPU program is passed to monitored
   3043         # training session to run. Infeed enqueue and outfeed dequeue are
   3044         # executed in side threads. This is not the configuration for
   3045         # prediction mode.
   3046         #
   3047         # For prediction, the Estimator executes the EstimatorSpec.predictions
   3048         # directly and yield the element (via generator) to call site. So, the
   3049         # outfeed based prediction must be passed to MonitoredSession directly.
   3050         # Other parts of the TPU execution are organized as follows.
   3051         #
   3052         # 1. All outfeed based Tensors must be grouped with predictions Tensors
   3053         #    to form a single invocation. This avoids the issue where we might
   3054         #    trigger multiple outfeeds incorrectly. To achieve this, `host_call` is
   3055         #    placed in control_dependencies of `stopping_signals`, and
   3056         #    `stopping_signals` is passed into _StoppingPredictHook, which sets
   3057         #    the `stopping_signals` as SessionRunArgs. MonitoredSession merges
   3058         #    all SessionRunArgs with the fetch in session.run together.
   3059         #
   3060         # 2. The TPU program (dummy_predict_op) and enqueue_ops (infeed Enqueue)
   3061         #    are grouped together. They will be launched once and only once in
   3062         #    side threads and they quit naturally according to the SAME stopping
   3063         #    condition.
   3064         enqueue_ops.append(dummy_predict_op)
   3065 
   3066         host_call_ret = host_calls.create_tpu_hostcall()
   3067         if 'host_call' not in host_call_ret:
   3068           host_ops = []
   3069         else:
   3070           host_ops = host_call_ret['host_call']
   3071 
   3072         predictions = host_call_ret['predictions']
   3073         _verify_cross_hosts_transfer_size(
   3074             predictions,
   3075             message=(
   3076                 'The estimated size for TPUEstimatorSpec.predictions is too '
   3077                 'large.'))
   3078         signals = host_call_ret['signals']
   3079 
   3080         with ops.control_dependencies(host_ops):
   3081           host_ops = []  # Empty; we do not need it anymore.
   3082           scalar_stopping_signal = _StopSignals.as_scalar_stopping_signal(
   3083               signals)
   3084           predictions = _PaddingSignals.slice_tensor_or_dict(
   3085               predictions, signals)
   3086 
   3087         hooks = [
   3088             _StoppingPredictHook(scalar_stopping_signal),
   3089             TPUInfeedOutfeedSessionHookForPrediction(
   3090                 ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode],
   3091                 tpu_compile_op=compile_op,
   3092                 master=self._config.master,
   3093                 session_config=self._session_config),
   3094         ] + input_hooks
   3095 
   3096         if prediction_hooks:
   3097           hooks.extend(prediction_hooks)
   3098 
   3099         return model_fn_lib.EstimatorSpec(
   3100             mode,
   3101             prediction_hooks=hooks,
   3102             predictions=predictions,
   3103             scaffold=scaffold)
   3104 
   3105     return _model_fn
   3106 
   3107 
   3108 def _export_output_to_tensors(export_output):
   3109   """Get a list of `Tensors` used in `export_output`.
   3110 
   3111   Args:
   3112     export_output: an `ExportOutput` object such as `ClassificationOutput`,
   3113       `RegressionOutput`, or `PredictOutput`.
   3114 
   3115   Returns:
   3116     a list of tensors used in export_output.
   3117 
   3118   Raises:
   3119     ValueError: if `export_output` is not one of `ClassificationOutput`,
   3120         `RegressionOutput`, or `PredictOutput`.
   3121   """
   3122   if isinstance(export_output, export_output_lib.ClassificationOutput):
   3123     return [export_output.scores, export_output.classes]
   3124   elif isinstance(export_output, export_output_lib.RegressionOutput):
   3125     return [export_output.value]
   3126   elif isinstance(export_output, export_output_lib.PredictOutput):
   3127     return list(export_output.outputs.values())
   3128   else:
   3129     raise ValueError(
   3130         '`export_output` must have type `ClassificationOutput`, '
   3131         '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output))
   3132 
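        # Illustrative mapping (comment only, not part of the original code):
        # for export_output_lib.ClassificationOutput(scores=s, classes=c) this
        # returns [s, c]; for export_output_lib.PredictOutput({'logits': l}) it
        # returns [l]. `s`, `c`, and `l` are placeholder tensors.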
   3133 
   3134 def _clone_export_output_with_tensors(export_output, tensors):
   3135   """Clones `export_output` but with new `tensors`.
   3136 
   3137   Args:
   3138     export_output: an `ExportOutput` object such as `ClassificationOutput`,
   3139       `RegressionOutput`, or `PredictOutput`.
   3140     tensors: a list of `Tensors` used to construct a new `export_output`.
   3141 
   3142   Returns:
   3143     A dict similar to `export_output` but with `tensors`.
   3144 
   3145   Raises:
   3146     ValueError: if `export_output` is not one of `ClassificationOutput`,
   3147         `RegressionOutput`, or `PredictOutput`.
   3148   """
   3149   if isinstance(export_output, export_output_lib.ClassificationOutput):
   3150     if len(tensors) != 2:
   3151       raise ValueError('tensors must be of length 2; '
   3152                        'got {}.'.format(len(tensors)))
   3153     return export_output_lib.ClassificationOutput(*tensors)
   3154   elif isinstance(export_output, export_output_lib.RegressionOutput):
   3155     if len(tensors) != 1:
   3156       raise ValueError('tensors must be of length 1; '
   3157                        'got {}'.format(len(tensors)))
   3158     return export_output_lib.RegressionOutput(*tensors)
   3159   elif isinstance(export_output, export_output_lib.PredictOutput):
   3160     return export_output_lib.PredictOutput(
   3161         dict(zip(export_output.outputs.keys(), tensors)))
   3162   else:
   3163     raise ValueError(
   3164         '`export_output` must have type `ClassificationOutput`, '
   3165         '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output))
   3166 
   3167 
   3168 def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
   3169   """Executes `model_fn_wrapper` multiple times on all TPU shards."""
   3170   iterations_per_loop_var = _create_or_get_iterations_per_loop()
   3171 
   3172   (single_tpu_eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks
   3173   ) = model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn)
   3174 
   3175   @tpu_function.on_device_training_loop
   3176   def multi_tpu_eval_steps_on_single_shard():
   3177     return training_loop.repeat(iterations_per_loop_var, single_tpu_eval_step,
   3178                                 [_ZERO_LOSS])
   3179 
   3180   (compile_op, loss,) = tpu.split_compile_and_shard(
   3181       multi_tpu_eval_steps_on_single_shard,
   3182       inputs=[],
   3183       num_shards=ctx.num_replicas,
   3184       outputs_from_all_shards=False,
   3185       device_assignment=ctx.device_assignment)
   3186 
   3187   loss = loss[0]
   3188   scaffold = _get_scaffold(captured_scaffold_fn)
   3189   return compile_op, loss, host_calls, scaffold, captured_eval_hooks.get()
   3190 
   3191 
   3192 def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
   3193   """Executes `model_fn_wrapper` multiple times on all TPU shards."""
   3194   iterations_per_loop_var = _create_or_get_iterations_per_loop()
   3195 
   3196   (single_tpu_train_step, host_call, captured_scaffold_fn,
   3197    captured_training_hooks) = (
   3198        model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn))
   3199 
   3200   @tpu_function.on_device_training_loop
   3201   def multi_tpu_train_steps_on_single_shard():
   3202     return training_loop.repeat(iterations_per_loop_var, single_tpu_train_step,
   3203                                 [_INITIAL_LOSS])
   3204 
   3205   (compile_op, loss,) = tpu.split_compile_and_shard(
   3206       multi_tpu_train_steps_on_single_shard,
   3207       inputs=[],
   3208       num_shards=ctx.num_replicas,
   3209       outputs_from_all_shards=False,
   3210       device_assignment=ctx.device_assignment)
   3211 
   3212   loss = loss[0]
   3213   scaffold = _get_scaffold(captured_scaffold_fn)
   3214   return compile_op, loss, host_call, scaffold, captured_training_hooks.get()
   3215 
   3216 
   3217 def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
   3218   """Executes `model_fn_wrapper` multiple times on all TPU shards."""
   3219   (single_tpu_predict_step, host_calls, captured_scaffold_fn,
   3220    captured_predict_hooks
   3221   ) = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn)
   3222 
   3223   @tpu_function.on_device_training_loop
   3224   def multi_tpu_predict_steps_on_single_shard():
   3225 
   3226     def cond(scalar_stopping_signal):
   3227       return math_ops.logical_not(
   3228           _StopSignals.should_stop(scalar_stopping_signal))
   3229 
   3230     inputs = [_StopSignals.NON_STOPPING_SIGNAL]
   3231     outputs = training_loop.while_loop(
   3232         cond, single_tpu_predict_step, inputs=inputs, name=b'loop')
   3233     return outputs
   3234 
   3235   (compile_op, dummy_predict_op,) = tpu.split_compile_and_shard(
   3236       multi_tpu_predict_steps_on_single_shard,
   3237       inputs=[],
   3238       num_shards=ctx.num_replicas,
   3239       outputs_from_all_shards=False,
   3240       device_assignment=ctx.device_assignment)
   3241 
   3242   dummy_predict_op = dummy_predict_op[0]
   3243   scaffold = _get_scaffold(captured_scaffold_fn)
   3244   return (compile_op, dummy_predict_op, host_calls, scaffold,
   3245           captured_predict_hooks.get())
   3246 
   3247 
   3248 def _wrap_computation_in_while_loop(device, op_fn):
   3249   """Wraps the ops generated by `op_fn` in tf.while_loop."""
   3250 
   3251   def computation(i):
   3252     with ops.control_dependencies(op_fn()):
   3253       return i + 1
   3254 
   3255   iterations_per_loop_var = _create_or_get_iterations_per_loop()
   3256   # By setting parallel_iterations=1, the parallel execution in while_loop is
   3257   # basically turned off.
   3258   with ops.device(device):
   3259     iterations = array_ops.identity(iterations_per_loop_var)
   3260     return control_flow_ops.while_loop(
   3261         lambda i: i < iterations,
   3262         computation, [constant_op.constant(0)],
   3263         parallel_iterations=1)
   3264 
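        # Usage sketch (comment only, not part of the original code): `op_fn`
        # must return a list of ops. A hypothetical call such as
        #
        #   loop = _wrap_computation_in_while_loop(
        #       device='/device:CPU:0', op_fn=lambda: [enqueue_op])
        #
        # runs the ops returned by `op_fn` once per iteration, for
        # iterations_per_loop iterations, with no parallelism across iterations.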
   3265 
   3266 def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn):
   3267   """Wraps the ops generated by `op_fn` in tf.while_loop."""
   3268 
   3269   def cond(scalar_stopping_signal):
   3270     return math_ops.logical_not(
   3271         _StopSignals.should_stop(scalar_stopping_signal))
   3272 
   3273   def computation(unused_scalar_stopping_signal):
   3274     return_value = op_fn()
   3275     execute_ops = return_value['ops']
   3276     signals = return_value['signals']
   3277     with ops.control_dependencies(execute_ops):
   3278       return _StopSignals.as_scalar_stopping_signal(signals)
   3279 
   3280   # By setting parallel_iterations=1, the parallel execution in while_loop is
   3281   # basically turned off.
   3282   with ops.device(device):
   3283     return control_flow_ops.while_loop(
   3284         cond,
   3285         computation, [_StopSignals.NON_STOPPING_SIGNAL],
   3286         parallel_iterations=1)
   3287 
   3288 
   3289 def _validate_tpu_training_graph():
   3290   """Validate graph before running distributed training.
   3291 
   3292   Raises:
   3293     ValueError: If the graph seems invalid for running on device
   3294   """
   3295   operations = ops.get_default_graph().get_operations()
   3296 
   3297   # Check if there is at least one CrossReplicaSum operation in the graph.
   3298   # This should be introduced by using the CrossShardOptimizer wrapper.
   3299   cross_replica_sum_ops = [
   3300       o for o in operations if o.type == _CROSS_REPLICA_SUM_OP
   3301   ]
   3302   if not cross_replica_sum_ops:
   3303     raise ValueError(
   3304         'CrossShardOptimizer must be used for model training on TPUs.')
   3305 
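        # Illustrative note (not part of the original code): the CrossReplicaSum
        # ops checked for above are normally introduced inside the user's
        # model_fn by wrapping the optimizer, e.g. (sketch; `base_optimizer` is
        # a placeholder):
        #
        #   from tensorflow.python.tpu import tpu_optimizer
        #   optimizer = tpu_optimizer.CrossShardOptimizer(base_optimizer)
        #   train_op = optimizer.minimize(loss, global_step=global_step)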
   3306 
   3307 class _CapturedObject(object):
   3308   """A placeholder to capture an object.
   3309 
   3310   This is useful when we need to capture a Python object in the TensorFlow
   3311   control flow body function and use it outside the control flow.
   3312   """
   3313 
   3314   def __init__(self):
   3315     self._object = None
   3316     self._captured = False
   3317 
   3318   def capture(self, o):
   3319     if self._captured:
   3320       raise RuntimeError(
   3321           'InternalError: Object can only capture once. Please file a bug.')
   3322 
   3323     self._captured = True
   3324     self._object = o
   3325 
   3326   def get(self):
   3327     if not self._captured:
   3328       raise RuntimeError(
   3329           'InternalError: Object is not captured properly before `get`. '
   3330           'Please file a bug.')
   3331     return self._object
   3332 
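        # Usage sketch (comment only, not part of the original code): a Python
        # object produced inside a control flow body can be retrieved afterwards.
        # `some_python_object` is a placeholder.
        #
        #   captured = _CapturedObject()
        #
        #   def computation():
        #     ...
        #     captured.capture(some_python_object)  # inside the body function
        #     return tensors
        #
        #   # After the computation graph has been built:
        #   some_python_object = captured.get()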
   3333 
   3334 def _get_scaffold(captured_scaffold_fn):
   3335   """Retrieves the Scaffold from `captured_scaffold_fn`."""
   3336   with _CapturingContext(message='Inside scaffold_fn'):
   3337     scaffold_fn = captured_scaffold_fn.get()
   3338     if scaffold_fn:
   3339       scaffold = scaffold_fn()
   3340       if scaffold is None:
   3341         raise ValueError(
   3342             'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')
   3343     else:
   3344       scaffold = None
   3345 
   3346   if scaffold:
   3347     wrapped_finalize = scaffold.finalize
   3348 
   3349     def _finalize():
   3350       with _CapturingContext('Inside Scaffold.finalize'):
   3351         wrapped_finalize()
   3352 
   3353     scaffold.finalize = _finalize
   3354   return scaffold
   3355 
   3356 
   3357 class _CapturingContext(control_flow_ops.ControlFlowContext):
   3358   """Tracks references to Tensors defined in TPU replication."""
   3359 
   3360   def __init__(self, message):
   3361     control_flow_ops.ControlFlowContext.__init__(self)
   3362     self._message = message
   3363 
   3364   def to_control_flow_context_def(self, context_def, export_scope=None):
   3365     # pylint: disable=useless-super-delegation
   3366     # NOTE(slebedev): the method is required by `ControlFlowContext`.
   3367     super(_CapturingContext, self).to_control_flow_context_def(
   3368         context_def, export_scope)
   3369 
   3370   def AddOp(self, op):  # pylint: disable=invalid-name
   3371     for c in op.inputs:
   3372       if tpu._TPU_REPLICATE_ATTR in c.op.node_def.attr:  # pylint: disable=protected-access
   3373         raise ValueError('{}: Op {} depends on TPU computation {}, '
   3374                          'which is not allowed.'.format(self._message, op, c))
   3375 
   3376   def __enter__(self):
   3377     # pylint: disable=protected-access
   3378     self._g = ops.get_default_graph()
   3379     self._old = self._g._get_control_flow_context()
   3380     self._g._set_control_flow_context(self)
   3381     # pylint: enable=protected-access
   3382 
   3383   def __exit__(self, _, __, ___):  # pylint: disable=invalid-name
   3384     self._g._set_control_flow_context(self._old)  # pylint: disable=protected-access
   3385 
   3386 
   3387 class _Inputs(object):
   3388   """A data structure representing the input_fn returned values.
   3389 
   3390   This also supports the returned value from input_fn as `Dataset`.
   3391   """
   3392 
   3393   def __init__(self, features=None, labels=None, dataset=None, signals=None):
   3394     if dataset is not None and (features is not None or labels is not None or
   3395                                 signals is not None):
   3396       raise RuntimeError('Internal Error: Either (features and labels) or '
   3397                          'dataset should be provided, not both. Please file '
   3398                          'a bug.')
   3399 
   3400     self._features = features
   3401     self._labels = labels
   3402     self._signals = signals
   3403 
   3404     self._dataset = dataset
   3405     self._iterator = None
   3406 
   3407   @staticmethod
   3408   def from_input_fn(return_values):
   3409     """Returns an `_Inputs` instance according to `input_fn` return value."""
   3410     if isinstance(return_values, dataset_ops.DatasetV2):
   3411       dataset = return_values
   3412       return _Inputs(dataset=dataset)
   3413 
   3414     features, labels = _Inputs._parse_inputs(return_values)
   3415     return _Inputs(features, labels)
   3416 
   3417   @staticmethod
   3418   def _parse_inputs(return_values):
   3419     if isinstance(return_values, tuple):
   3420       features, labels = return_values
   3421     else:
   3422       features, labels = return_values, None
   3423     return features, labels
   3424 
   3425   @property
   3426   def is_dataset(self):
   3427     """Returns True if the return value from input_fn is Dataset."""
   3428     return self._dataset is not None
   3429 
   3430   def dataset_initializer(self):
   3431     """Returns the dataset's initializer.
   3432 
   3433     The initializer must be run before calling `features_and_labels`.
   3434     """
   3435     self._iterator = dataset_ops.make_initializable_iterator(self._dataset)
   3436     return self._iterator.initializer
   3437 
   3438   def features_and_labels(self):
   3439     """Gets `features` and `labels`."""
   3440     if self.is_dataset:
   3441       if self._iterator is None:
   3442         raise RuntimeError('Internal error: Must run dataset_initializer '
   3443                            'before calling features_and_labels(). Please file '
   3444                            'a bug!')
   3445       return _Inputs._parse_inputs(self._iterator.get_next())
   3446 
   3447     return (self._features, self._labels)
   3448 
   3449   def signals(self):
   3450     return self._signals
   3451 
   3452   @property
   3453   def dataset(self):
   3454     return self._dataset
   3455 
   3456 
   3457 class _InputsWithStoppingSignals(_Inputs):
   3458   """Inputs with `_StopSignals` inserted into the dataset."""
   3459 
   3460   def __init__(self,
   3461                dataset,
   3462                batch_size,
   3463                add_padding=False,
   3464                num_invocations_per_step=1):
   3465 
   3466     assert dataset is not None
   3467     user_provided_dataset = dataset.map(
   3468         _InputsWithStoppingSignals.insert_stopping_signal(
   3469             stop=False, batch_size=batch_size, add_padding=add_padding))
   3470     if num_invocations_per_step == 1:
   3471       final_batch_dataset = dataset.take(1).map(
   3472           _InputsWithStoppingSignals.insert_stopping_signal(
   3473               stop=True, batch_size=batch_size, add_padding=add_padding))
   3474     else:
   3475       # We append (2 * num_invocations_per_step - 1) batches to exhaust the
   3476       # user_provided_dataset and stop properly.
   3477       # For example, if num_invocations_per_step is 2, we append 3 additional
   3478       # padding batches: b1, b2, b3.
   3479       # If user_provided_dataset contains two batches: a1, a2
   3480       # Step 1: [a1, a2]
   3481       # Step 2: [b1, b2] -> STOP
   3482       # If user_provided_dataset contains three batches: a1, a2, a3.
   3483       # The training loops:
   3484       # Step 1: [a1, a2]
   3485       # Step 2: [a3, b1]
   3486       # Step 3: [b2, b3] -> STOP.
   3487       final_batch_dataset = dataset.take(1).map(
   3488           _InputsWithStoppingSignals.insert_stopping_signal(
   3489               stop=True, batch_size=batch_size, add_padding=add_padding))
   3490       final_batch_dataset = final_batch_dataset.repeat(
   3491           2 * num_invocations_per_step - 1)
   3492 
   3493       def _set_mask(data_dict):
   3494         signals = data_dict['signals']
   3495         signals['padding_mask'] = array_ops.ones_like(signals['padding_mask'])
   3496         data_dict['signals'] = signals
   3497         return data_dict
   3498 
   3499       # Mask out the extra batch.
   3500       final_batch_dataset = final_batch_dataset.map(_set_mask)
   3501 
   3502     dataset = user_provided_dataset.concatenate(final_batch_dataset).prefetch(2)
   3503 
   3504     super(_InputsWithStoppingSignals, self).__init__(dataset=dataset)
   3505     self._current_inputs = None
   3506 
   3507   def features_and_labels(self):
   3508     if self._current_inputs is not None:
   3509       raise RuntimeError(
   3510           'Internal Error: The previous inputs have not been properly '
   3511           'consumed. First call features_and_labels, then call signals.')
   3512 
   3513     inputs_with_signals = self._iterator.get_next()
   3514     features = inputs_with_signals['features']
   3515     labels = inputs_with_signals.get('labels')
   3516 
   3517     self._current_inputs = inputs_with_signals
   3518     return features, labels
   3519 
   3520   def signals(self):
   3521     """Returns the `Signals` from `_Inputs`."""
   3522     if self._current_inputs is None:
   3523       raise RuntimeError(
   3524           'Internal Error: The current inputs have not been properly '
   3525           'generated. First call features_and_labels, then call signals.')
   3526     signals = self._current_inputs['signals']
   3527     self._current_inputs = None
   3528     return signals
   3529 
   3530   @staticmethod
   3531   def insert_stopping_signal(stop, batch_size, add_padding=False):
   3532     """Inserts stopping_signal into dataset via _map_fn.
   3533 
   3534     Here we change the data structure in the dataset, such that the return value
   3535     is now a dictionary in which `features`, `labels`, and `signals` are three
   3536     distinct keys. This provides a better structure, which eases the process of
   3537     decomposing the inputs (see `features_and_labels`).
   3538 
   3539     Args:
   3540       stop: bool, state of current stopping signals.
   3541       batch_size: int, batch size.
   3542       add_padding: bool, whether to pad the tensor to full batch size.
   3543 
   3544     Returns:
   3545       A map_fn to be passed to the dataset.map API.
   3546     """
   3547 
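            # For illustration, a dataset element `(features, labels)` is mapped to
            # roughly the following structure:
            #   {'features': features,
            #    'labels': labels,    # only if labels are present
            #    'signals': {'stopping': <bool Tensor of shape [batch_size, 1]>,
            #                'padding_mask': <int32 Tensor of shape [batch_size]>}}
            #   ('padding_mask' is present only when add_padding=True.)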
   3548     def _map_fn(*args):
   3549       """The map fn to insert signals."""
   3550       if len(args) == 1:
   3551         # Unpack the single Tensor/dict argument as features. This is required
   3552         # when the input_fn returns no labels.
   3553         args = args[0]
   3554       features, labels = _Inputs._parse_inputs(args)
   3555       new_input_dict = {}
   3556 
   3557       if add_padding:
   3558         padding_mask, features, labels = (
   3559             _PaddingSignals.pad_features_and_labels(features, labels,
   3560                                                     batch_size))
   3561 
   3562         new_input_dict['features'] = features
   3563         if labels is not None:
   3564           new_input_dict['labels'] = labels
   3565 
   3566       else:
   3567         new_input_dict['features'] = features
   3568         if labels is not None:
   3569           new_input_dict['labels'] = labels
   3570         padding_mask = None
   3571 
   3572       new_input_dict['signals'] = _StopSignals(
   3573           stop=stop, batch_size=batch_size,
   3574           padding_mask=padding_mask).as_dict()
   3575 
   3576       return new_input_dict
   3577 
   3578     return _map_fn
   3579 
   3580 
   3581 class _StopSignals(object):
   3582   """Signals class holding all logic to handle TPU stopping condition."""
   3583 
   3584   NON_STOPPING_SIGNAL = False
   3585   STOPPING_SIGNAL = True
   3586 
   3587   def __init__(self, stop, batch_size, padding_mask=None):
   3588     self._stop = stop
   3589     self._batch_size = batch_size
   3590     self._padding_mask = padding_mask
   3591 
   3592   def as_dict(self):
   3593     """Returns the signals as Python dict."""
   3594     shape = [self._batch_size, 1]
   3595     dtype = dtypes.bool
   3596 
   3597     if self._stop:
   3598       stopping = array_ops.ones(shape=shape, dtype=dtype)
   3599     else:
   3600       stopping = array_ops.zeros(shape=shape, dtype=dtype)
   3601 
   3602     signals = {'stopping': stopping}
   3603     if self._padding_mask is not None:
   3604       signals['padding_mask'] = self._padding_mask
   3605     return signals
   3606 
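          # Every row of `signals['stopping']` carries the same value, so reading
          # element [0][0] is enough to recover the scalar stopping flag.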
   3607   @staticmethod
   3608   def as_scalar_stopping_signal(signals):
   3609     return array_ops.identity(signals['stopping'][0][0])
   3610 
   3611   @staticmethod
   3612   def should_stop(scalar_stopping_signal):
   3613     """Detects whether scalar_stopping_signal indicates stopping."""
   3614     if isinstance(scalar_stopping_signal, ops.Tensor):
   3615       # STOPPING_SIGNAL is a constant True. Here, logical_and is just the TF
   3616       # way to express a boolean check on whether scalar_stopping_signal is True.
   3617       return math_ops.logical_and(scalar_stopping_signal,
   3618                                   _StopSignals.STOPPING_SIGNAL)
   3619     else:
   3620       # In the non-Tensor case, this is called from a SessionRunHook, so we
   3621       # cannot modify the graph anymore. Here, we use pure Python.
   3622       return bool(scalar_stopping_signal)
   3623 
   3624 
   3625 class _PaddingSignals(object):
   3626   """Signals class holding all logic to handle padding."""
   3627 
   3628   @staticmethod
   3629   def pad_features_and_labels(features, labels, batch_size):
   3630     """Pads out the batch dimension of features and labels."""
   3631     real_batch_size = array_ops.shape(
   3632         _PaddingSignals._find_any_tensor(features))[0]
   3633 
   3634     batch_size_tensor = constant_op.constant(batch_size, dtypes.int32)
   3635 
   3636     check_greater = check_ops.assert_greater_equal(
   3637         batch_size_tensor,
   3638         real_batch_size,
   3639         data=(batch_size_tensor, real_batch_size),
   3640         message='The real batch size should not be greater than batch_size.')
   3641 
   3642     with ops.control_dependencies([check_greater]):
   3643       missing_count = batch_size_tensor - real_batch_size
   3644 
   3645     def pad_single_tensor(tensor):
   3646       """Pads out the batch dimension of a tensor to the complete batch_size."""
   3647       rank = len(tensor.shape)
   3648       assert rank > 0
   3649       padding = array_ops.stack([[0, missing_count]] + [[0, 0]] * (rank - 1))
   3650       padded_shape = (batch_size,) + tuple(tensor.shape[1:])
   3651       padded_tensor = array_ops.pad(tensor, padding)
   3652       padded_tensor.set_shape(padded_shape)
   3653       return padded_tensor
   3654 
   3655     def nest_pad(tensor_or_dict):
   3656       return nest.map_structure(pad_single_tensor, tensor_or_dict)
   3657 
   3658     features = nest_pad(features)
   3659     if labels is not None:
   3660       labels = nest_pad(labels)
   3661 
   3662     padding_mask = _PaddingSignals._padding_mask(real_batch_size, missing_count,
   3663                                                  batch_size)
   3664 
   3665     return padding_mask, features, labels
   3666 
   3667   @staticmethod
   3668   def slice_tensor_or_dict(tensor_or_dict, signals):
   3669     """Slice the real Tensors according to padding mask in signals."""
   3670 
   3671     padding_mask = signals['padding_mask']
   3672     batch_size = array_ops.shape(padding_mask)[0]
   3673 
   3674     def verify_batch_size(tensor):
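              # Note: `check_batch_size` is only threaded in as a control
              # dependency; being a plain bool Tensor rather than an assert op, it
              # does not raise if the leading dimensions disagree.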
   3675       check_batch_size = math_ops.equal(batch_size, tensor.shape[0])
   3676       with ops.control_dependencies([check_batch_size]):
   3677         return array_ops.identity(tensor)
   3678 
   3679     def slice_single_tensor(tensor):
   3680       rank = len(tensor.shape)
   3681       assert rank > 0
   3682       real_batch_size = batch_size - math_ops.reduce_sum(padding_mask)
   3683       return verify_batch_size(tensor)[0:real_batch_size]
   3684 
   3685     # As we split the Tensors across all TPU cores and concat them back, it is
   3686     # important to ensure the real data is placed before the padded data, i.e.,
   3687     # that order is preserved. Given that, the sliced padding mask should be all
   3688     # 0's. If this assertion failed, the slice logic here would not hold.
   3689     sliced_padding_mask = slice_single_tensor(padding_mask)
   3690     assert_padding_mask = math_ops.equal(
   3691         math_ops.reduce_sum(sliced_padding_mask), 0)
   3692 
   3693     with ops.control_dependencies([assert_padding_mask]):
   3694       should_stop = _StopSignals.should_stop(
   3695           _StopSignals.as_scalar_stopping_signal(signals))
   3696 
   3697     is_full_batch = math_ops.equal(math_ops.reduce_sum(padding_mask), 0)
   3698 
   3699     def slice_fn(tensor):
   3700       # If the current batch is a full batch or part of the stopping signals, we
   3701       # do not need to slice, which saves time.
   3702       return control_flow_ops.cond(
   3703           math_ops.logical_or(should_stop, is_full_batch),
   3704           (lambda: verify_batch_size(tensor)),
   3705           (lambda: slice_single_tensor(tensor)))
   3706 
   3707     return nest.map_structure(slice_fn, tensor_or_dict)
   3708 
   3709   @staticmethod
   3710   def _find_any_tensor(batch_features):
   3711     tensors = [
   3712         x for x in nest.flatten(batch_features) if isinstance(x, ops.Tensor)
   3713     ]
   3714     if not tensors:
   3715       raise ValueError('Cannot find any Tensor in features dict.')
   3716     return tensors[0]
   3717 
   3718   @staticmethod
   3719   def _padding_mask(real_batch_size, missing_count, batch_size):
   3720     padding_mask = array_ops.concat([
   3721         array_ops.zeros((real_batch_size,), dtype=dtypes.int32),
   3722         array_ops.ones((missing_count,), dtype=dtypes.int32)
   3723     ],
   3724                                     axis=0)
   3725     padding_mask.set_shape((batch_size,))
   3726     return padding_mask
   3727 
   3728 
   3729 def _verify_cross_hosts_transfer_size(tensor_dict, message):
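          """Raises ValueError if the static size of `tensor_dict` exceeds the limit."""
          # `_ONE_GIGABYTE` serves as a conservative bound below protobuf's 2GB
          # serialized-message limit.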
   3730   total_size = 0
   3731   tensor_structure = {}
   3732   for key, tensor in tensor_dict.items():
   3733     shape = tensor.shape
   3734     size = np.product(shape) * tensor.dtype.size
   3735     tensor_structure[key] = shape
   3736     total_size += size
   3737   if total_size >= _ONE_GIGABYTE:
   3738     raise ValueError(
   3739         '{} The transfer size is larger than the protobuf limit. Please '
   3740         'consider using Tensors with smaller shapes or reducing the batch '
   3741         'size. Given:\n'
   3742         '{}'.format(
   3743             message, '\n'.join([
   3744                 ' -- Key: {}, Shape: {}'.format(k, v)
   3745                 for k, v in tensor_structure.items()
   3746             ])))
   3747 
   3748 
   3749 def _add_item_to_params(params, key, value):
   3750   """Adds a new item into `params`."""
   3751   if hasattr(params, 'set_hparam'):
   3752     # For HParams, we need to use its special API.
   3753     if key in params:
   3754       params.set_hparam(key, value)
   3755     else:
   3756       params.add_hparam(key, value)
   3757   else:
   3758     # Otherwise, params is a plain Python dict.
   3759     params[key] = value
   3760 
   3761 
   3762 def export_estimator_savedmodel(estimator,
   3763                                 export_dir_base,
   3764                                 serving_input_receiver_fn,
   3765                                 assets_extra=None,
   3766                                 as_text=False,
   3767                                 checkpoint_path=None,
   3768                                 strip_default_attrs=False):
   3769   """Export `Estimator` trained model for TPU inference.
   3770 
   3771   Args:
   3772     estimator: The `Estimator` with which the model has been trained.
   3773     export_dir_base: A string containing a directory in which to create
   3774       timestamped subdirectories containing exported SavedModels.
   3775     serving_input_receiver_fn: A function that takes no argument and returns a
   3776       `ServingInputReceiver` or `TensorServingInputReceiver`.
   3777     assets_extra: A dict specifying how to populate the assets.extra directory
   3778       within the exported SavedModel, or `None` if no extra assets are needed.
   3779     as_text: whether to write the SavedModel proto in text format.
   3780     checkpoint_path: The checkpoint path to export.  If `None` (the default),
   3781       the most recent checkpoint found within the model directory is chosen.
   3782     strip_default_attrs: Boolean. If `True`, default-valued attributes will be
   3783       removed from the NodeDefs.
   3784 
   3785   Returns:
   3786     The string path to the exported directory.
   3787   """
   3788   # `TPUEstimator` requires `tpu_config.RunConfig`, so we cannot use
   3789   # `estimator.config`.
   3790   config = tpu_config.RunConfig(model_dir=estimator.model_dir)
   3791   est = TPUEstimator(
   3792       estimator._model_fn,  # pylint: disable=protected-access
   3793       config=config,
   3794       params=estimator.params,
   3795       use_tpu=True,
   3796       train_batch_size=2048,  # Does not matter.
   3797       eval_batch_size=2048,  # Does not matter.
   3798   )
   3799   return est.export_savedmodel(export_dir_base, serving_input_receiver_fn,
   3800                                assets_extra, as_text, checkpoint_path,
   3801                                strip_default_attrs)
   3802
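        # A minimal usage sketch for `export_estimator_savedmodel` (the estimator,
        # export path, and serving_input_receiver_fn below are placeholders, not
        # part of this module):
        #
        #   export_dir = export_estimator_savedmodel(
        #       estimator=my_estimator,
        #       export_dir_base='/tmp/tpu_saved_model',
        #       serving_input_receiver_fn=my_serving_input_receiver_fn)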