# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU system metadata and associated tooling."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from contextlib import contextmanager
import copy

from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import _tpu_estimator_embedding
from tensorflow.python.tpu import device_assignment as tpu_device_assignment
from tensorflow.python.tpu import tpu_config
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib


_DEFAULT_JOB_NAME = 'tpu_worker'
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
_LOCAL_MASTERS = ('', 'local')
_NUM_CORES_TO_COMPUTATION_SHAPE = {
    1: [1, 1, 1],
    2: [1, 1, 2],
    4: [1, 2, 2],
    8: [2, 2, 2],
    16: [4, 2, 2],
}
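# The mapping above gives, for each supported TPUConfig.num_cores_per_replica,
# the 3-D shape of the per-replica computation on the TPU topology (on a 2-D
# chip grid with two cores per chip). For example, 8 cores per replica maps to
# [2, 2, 2]: roughly a 2x2 footprint of chips using both cores on each chip.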


class TPUContext(object):
  """A context that holds the current configuration of the TPU computation."""

  def __init__(self,
               internal_ctx,
               input_device=None,
               invocation_index=None,
               call_from_input_fn=True):
    self._internal_ctx = internal_ctx
    self._input_device = input_device
    self._invocation_index = invocation_index
    self._call_from_input_fn = call_from_input_fn

  def current_input_fn_deployment(self):
    """The configuration of the current input_fn invocation.

    The configuration depends on `TPUConfig.per_host_input_for_training`. See
    `TPUConfig` for details.

    Only available in the params dict passed to input_fn.

    Returns:
      A tuple of
        1. Device spec string: String, the current CPU host on which the
           input_fn is invoked.
        2. Current invocation index: Int, 0-based index of the input_fn
           invocation. See the next item for details.
        3. Total invocation count: Int, the total number of times the input_fn
           is invoked across all CPU hosts. Each invocation is passed a new
           `TPUContext` instance with the current invocation index set
           properly.
        4. Total number of replicas consumed by the current invocation: Int,
           the number of replicas fed by the data returned by the current
           input_fn. For example, for per-core input pipeline deployment
           without model parallelism, the total invocation count equals the
           number of cores in the system and the number of replicas consumed
           by the current invocation is 1. For per-host v2 input pipeline
           deployment, the total invocation count equals the number of hosts
           in the system and the number of replicas consumed by the current
           invocation equals the number of cores per host.

    Raises:
      RuntimeError: If this method is called from model_fn instead of
        input_fn.
    """
    if not self._call_from_input_fn:
      raise RuntimeError('This TPUContext instance must not be called from'
                         ' model_fn.')

    if self._internal_ctx.is_input_sharded_per_core():
      total_invocation_count = (self._internal_ctx.num_hosts
                                * self._internal_ctx.num_of_replicas_per_host)
      replicas_consumed = 1
    elif self._internal_ctx.is_input_broadcast_with_iterators():
      total_invocation_count = 1
      replicas_consumed = self._internal_ctx.num_replicas
    else:
      total_invocation_count = self._internal_ctx.num_hosts
      replicas_consumed = self._internal_ctx.num_of_replicas_per_host
    return (self._input_device, self._invocation_index,
            total_invocation_count, replicas_consumed)
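  # A minimal sketch of how an input_fn might use this deployment info to
  # shard its data. It assumes the surrounding estimator passes a TPUContext
  # to input_fn in the params dict (TPUEstimator exposes it under the
  # 'context' key); the dataset construction below is illustrative:
  #
  #   def input_fn(params):
  #     ctx = params['context']
  #     _, index, count, _ = ctx.current_input_fn_deployment()
  #     dataset = ...  # build the full dataset
  #     return dataset.shard(count, index).batch(params['batch_size'])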

  @property
  def num_replicas(self):
    """The total number of replicas.

    For non-model-parallelism, num_replicas should be the total num of TPU
    cores in the system.

    Returns:
      The number of replicas.
    """
    return self._internal_ctx.num_replicas

  @property
  def num_hosts(self):
    """The number of hosts for the TPU system."""
    return self._internal_ctx.num_hosts

  @property
  def current_host(self):
    """The current host index for the TPU system."""
    return self._invocation_index

  @property
  def num_of_replicas_per_host(self):
    """The number of replicas for each host."""
    if self._internal_ctx.model_parallelism_enabled:
      raise ValueError(
          'num_of_replicas_per_host is not supported for model_parallelism')
    return self._internal_ctx.num_of_replicas_per_host

  @property
  def device_assignment(self):
    """Returns the device_assignment object."""
    if self._call_from_input_fn:
      raise RuntimeError('This TPUContext instance must not be called from'
                         ' input_fn.')
    return self._internal_ctx.device_assignment

  def device_for_replica(self, replica_id):
    """Returns the tuple of (CPU device, device ordinal) for the replica.

    This should be used for full replication (non-model-parallelism).

    Args:
       replica_id: Int, the replica index.

    Returns:
       A tuple of the device spec for the CPU device and the int device
       ordinal.
    """
    # Note: for non-model parallelism, the mapping could be a random
    # permutation. The order should not matter in most cases, as long as the
    # model is replicated to all cores in the system.
    return self._internal_ctx.device_for_replica(replica_id)

  @property
  def tpu_host_placement_function(self):
    """Returns the TPU host placement function.

    The placement function takes host_id as input and returns the TF device
    for the corresponding host.
    """

    def _placement_function(host_id):
      """Return the host device given host_id."""
      return self._internal_ctx.tpu_host_placement_function(host_id=host_id)

    return _placement_function


class _InternalTPUContext(object):
  """A context that holds the immutable state of a TPU computation.

  This immutable object holds the TPUEstimator config, train/eval batch
  sizes, and `TPUEstimator.use_tpu`, and is expected to be passed around. It
  also provides utility functions, based on the current state, to determine
  other information commonly required by TPU computation, such as TPU device
  names, TPU hosts, shard batch size, etc.

  If eval_on_tpu is False, execution of eval on TPU is disabled.
  If eval_on_tpu is True but use_tpu is False, a warning is issued and TPU
  execution is disabled for all modes.

  N.B. As `mode` is not immutable state in Estimator, but is essential to
  distinguish between TPU training and evaluation, a common usage of
  _InternalTPUContext with `mode` is as follows:
  ```
  with _ctx.with_mode(mode) as ctx:
    if ctx.is_running_on_cpu():
       ...
  ```
  """

  def __init__(self,
               config,
               train_batch_size,
               eval_batch_size,
               predict_batch_size,
               use_tpu,
               eval_on_tpu=True,
               embedding_config_spec=None):
    self._config = config
    self._train_batch_size = train_batch_size
    self._eval_batch_size = eval_batch_size
    self._predict_batch_size = predict_batch_size
    self._use_tpu = use_tpu
    logging.info('_TPUContext: eval_on_tpu %s', eval_on_tpu)
    if not use_tpu and eval_on_tpu:
      logging.warning('eval_on_tpu ignored because use_tpu is False.')

    self._eval_on_tpu = eval_on_tpu
    self._model_parallelism_enabled = (
        use_tpu and config.tpu_config.num_cores_per_replica)
    self._mode = None
    num_cores_per_replica = config.tpu_config.num_cores_per_replica
    if self._model_parallelism_enabled:
      self._computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[
          num_cores_per_replica]
    else:
      self._computation_shape = None
    self._lazy_tpu_system_metadata_dict = {}  # key by master address
    self._lazy_device_assignment_dict = {}  # key by master address
    self._lazy_validation_dict = {}  # key by ModeKeys
    self._embedding_config_spec = embedding_config_spec
    self._lazy_embedding_config_dict = {}  # key by master address

  def _assert_mode(self):
    if self._mode is None:
      raise RuntimeError(
          '`mode` needs to be set via contextmanager `with_mode`.')
    return self._mode

  @contextmanager
  def with_mode(self, mode):
    # NOTE(xiejw): A shallow copy is enough. It shares the lazy dictionaries,
    # such as _lazy_tpu_system_metadata_dict, between the new copy and the
    # original one. Note that all lazy states stored in _lazy_foo properties
    # are effectively immutable, as they should stay the same for the process
    # lifetime.
    new_ctx = copy.copy(self)
    new_ctx._mode = mode  # pylint: disable=protected-access
    yield new_ctx

  @property
  def mode(self):
    return self._assert_mode()

  def _get_master_address(self):
    mode = self._assert_mode()
    config = self._config
    master = (
        config.master
        if mode != model_fn_lib.ModeKeys.EVAL else config.evaluation_master)
    return master

  def _get_tpu_system_metadata(self):
    """Gets the (maybe cached) TPU system metadata."""
    master = self._get_master_address()
    tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)
    if tpu_system_metadata is not None:
      return tpu_system_metadata

    cluster_def = None
    if (self._config.session_config and
        self._config.session_config.cluster_def.job):
      cluster_def = self._config.session_config.cluster_def

    # pylint: disable=protected-access
    tpu_system_metadata = (
        tpu_system_metadata_lib._query_tpu_system_metadata(
            master,
            cluster_def=cluster_def,
            query_topology=self.model_parallelism_enabled))

    self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata
    return tpu_system_metadata

  def _get_device_assignment(self):
    """Gets the (maybe cached) TPU device assignment."""
    master = self._get_master_address()
    device_assignment = self._lazy_device_assignment_dict.get(master)
    if device_assignment is not None:
      return device_assignment

    tpu_system_metadata = self._get_tpu_system_metadata()

    device_assignment = tpu_device_assignment.device_assignment(
        tpu_system_metadata.topology,
        computation_shape=self._computation_shape,
        num_replicas=self.num_replicas)

    logging.info('num_cores_per_replica: %s',
                 str(self._config.tpu_config.num_cores_per_replica))
    logging.info('computation_shape: %s', str(self._computation_shape))
    logging.info('num_replicas: %d', self.num_replicas)
    logging.info('device_assignment.topology.device_coordinates: %s',
                 str(device_assignment.topology.device_coordinates))
    logging.info('device_assignment.core_assignment: %s',
                 str(device_assignment.core_assignment))

    self._lazy_device_assignment_dict[master] = device_assignment
    return device_assignment

  @property
  def embedding_config(self):
    """Returns the embedding config based on current mode."""
    master = self._get_master_address()
    if master in self._lazy_embedding_config_dict:
      embedding_config = self._lazy_embedding_config_dict[master]
    else:
      embedding_config = None
      if self._use_tpu and self._embedding_config_spec:
        embedding_config = _tpu_estimator_embedding.EmbeddingConfig(
            self._embedding_config_spec, self._train_batch_size,
            self._eval_batch_size, self.num_hosts, self.num_cores, self.config)
        if not embedding_config.has_embedding_tables():
          embedding_config = None
      self._lazy_embedding_config_dict[master] = embedding_config

    if embedding_config is not None:
      mode = self._assert_mode()
      # Dynamically attach tpu_embedding based on mode. With this,
      # embedding_config stays immutable while call sites always access the
      # unified API '.tpu_embedding'.
      embedding_config.tpu_embedding = embedding_config.get_tpu_embedding(mode)
    return embedding_config

  @property
  def model_parallelism_enabled(self):
    return self._model_parallelism_enabled

  @property
  def input_partition_dims(self):
    return self._config.tpu_config.input_partition_dims

  @property
  def device_assignment(self):
    return (self._get_device_assignment()
            if self._model_parallelism_enabled else None)

  @property
  def num_of_cores_per_host(self):
    metadata = self._get_tpu_system_metadata()
    return metadata.num_of_cores_per_host

  @property
  def num_cores(self):
    metadata = self._get_tpu_system_metadata()
    return metadata.num_cores

  @property
  def num_of_replicas_per_host(self):
    """Return the number of replicas per host."""
    if self.model_parallelism_enabled:
      return self.num_replicas // self.num_hosts
    else:
      return self.num_of_cores_per_host

  @property
  def num_replicas(self):
    num_cores_in_system = self.num_cores

    if self.model_parallelism_enabled:
      num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
      if num_cores_per_replica > num_cores_in_system:
        raise ValueError(
            'The num of cores required by the model parallelism, specified by '
            'TPUConfig.num_cores_per_replica, is larger than the total num of '
            'TPU cores in the system. num_cores_per_replica: {}, num cores '
            'in the system: {}'.format(num_cores_per_replica,
                                       num_cores_in_system))

      if num_cores_in_system % num_cores_per_replica != 0:
        raise RuntimeError(
            'The num of cores in the system ({}) is not divisible by the num '
            'of cores ({}) required by the model parallelism, specified by '
            'TPUConfig.num_cores_per_replica. This should never happen!'.format(
                num_cores_in_system, num_cores_per_replica))

      return num_cores_in_system // num_cores_per_replica
    else:
      return num_cores_in_system
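  # For example, on a 32-core slice with TPUConfig.num_cores_per_replica=8,
  # this yields 32 // 8 = 4 model-parallel replicas; without model
  # parallelism it is simply the 32 cores, one replica per core.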

  @property
  def num_hosts(self):
    metadata = self._get_tpu_system_metadata()
    return metadata.num_hosts

  @property
  def config(self):
    return self._config

  def is_input_sharded_per_core(self):
    """Return true if input_fn is invoked per-core (rather than per-host)."""
    mode = self._assert_mode()
    return (mode == model_fn_lib.ModeKeys.TRAIN and
            (self._config.tpu_config.per_host_input_for_training is
             tpu_config.InputPipelineConfig.PER_SHARD_V1))

  def is_input_per_host_with_iterators(self):
    """Return true if input_fn should be run in the per-host v2 config."""
    return (self._config.tpu_config.per_host_input_for_training is
            tpu_config.InputPipelineConfig.PER_HOST_V2)

  def is_input_broadcast_with_iterators(self):
    """Return true if input_fn should be run in the full replicate config."""
    mode = self._assert_mode()
    return ((self._config.tpu_config.per_host_input_for_training is
             tpu_config.InputPipelineConfig.BROADCAST) or
            (mode != model_fn_lib.ModeKeys.TRAIN and
             self._config.tpu_config.eval_training_input_configuration is
             tpu_config.InputPipelineConfig.SLICED))
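  # Summary of the input pipeline deployments checked above:
  #   * PER_SHARD_V1: input_fn is invoked once per core (training only).
  #   * PER_HOST_V1 / PER_HOST_V2: input_fn is invoked once per host.
  #   * BROADCAST (or SLICED for eval/predict): input_fn is invoked once and
  #     its data feeds all replicas.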

  def is_running_on_cpu(self, is_export_mode=False):
    """Determines whether the input_fn and model_fn should be invoked on CPU.

    This API also validates user-provided configuration, such as batch size,
    according to the lazily initialized TPU system metadata.

    Args:
      is_export_mode: Indicates whether the current mode is for exporting the
        model when mode == PREDICT. Only with this bool can we tell whether
        the user is calling Estimator.predict or Estimator.export_savedmodel,
        which run on TPU and CPU respectively. The parent class Estimator
        does not distinguish these two.

    Returns:
      bool, whether the current input_fn or model_fn should be running on CPU.

    Raises:
      ValueError: If any configuration is invalid.
    """

    is_running_on_cpu = self._is_running_on_cpu(is_export_mode)
    if not is_running_on_cpu:
      self._validate_tpu_configuration()
    return is_running_on_cpu

  def _is_running_on_cpu(self, is_export_mode):
    """Determines whether the input_fn and model_fn should be invoked on CPU."""
    mode = self._assert_mode()

    if not self._use_tpu:
      return True

    if mode == model_fn_lib.ModeKeys.EVAL and not self._eval_on_tpu:
      logging.info('_is_running_on_cpu: eval_on_tpu disabled')
      return True

    if is_export_mode:
      return True

    return False

  @property
  def global_batch_size(self):
    mode = self._assert_mode()
    if mode == model_fn_lib.ModeKeys.TRAIN:
      return self._train_batch_size
    elif mode == model_fn_lib.ModeKeys.EVAL:
      return self._eval_batch_size
    elif mode == model_fn_lib.ModeKeys.PREDICT:
      return self._predict_batch_size
    else:
      return None

  @property
  def batch_size_for_input_fn(self):
    """Returns the shard batch size for `input_fn`."""
    global_batch_size = self.global_batch_size
    if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()):
      return global_batch_size

    # On TPU
    if self.is_input_sharded_per_core() or (
        self.is_input_per_host_with_iterators()):
      return global_batch_size // self.num_replicas
    else:
      return global_batch_size // self.num_hosts
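  # For illustration: with a global batch size of 1024 on 4 hosts x 8 cores
  # (32 replicas, no model parallelism), per-core and per-host v2 input
  # pipelines get 1024 // 32 = 32 per invocation, per-host v1 gets
  # 1024 // 4 = 256, and broadcast input keeps the full 1024.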

  @property
  def batch_size_for_model_fn(self):
    """Returns the shard batch size for `model_fn`."""
    global_batch_size = self.global_batch_size

    if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()):
      return global_batch_size

    # On TPU, the batch is always sharded per replica.
    return global_batch_size // self.num_replicas
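  # In the same illustrative setup (global batch 1024, 32 replicas), each
  # model_fn shard sees a batch of 1024 // 32 = 32, regardless of how the
  # input pipeline was deployed (broadcast being the exception above).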

  @property
  def master_job(self):
    """Returns the job name to use for placing TPU computations.

    Returns:
      A string containing the job name, or None if no job should be specified.

    Raises:
      ValueError: If the user needs to specify a tpu_job_name, because we are
        unable to infer the job name automatically, or if the user-specified
        job names are inappropriate.
    """
    run_config = self._config
    # If the user specifies the tpu_job_name, use that.
    if run_config.tpu_config.tpu_job_name:
      return run_config.tpu_config.tpu_job_name

    # The tpu job is determined by the run_config. Right now, this method is
    # required as tpu_config is not part of the RunConfig.
    mode = self._assert_mode()
    master = (
        run_config.evaluation_master
        if mode == model_fn_lib.ModeKeys.EVAL else run_config.master)
    cluster_def = (run_config.session_config.cluster_def
                   if run_config.session_config else None)

    return tpu_system_metadata_lib.master_job(master, cluster_def)

  @property
  def tpu_host_placement_function(self):
    """Returns the TPU host placement function."""

    master = self.master_job

    def _placement_function(_sentinel=None, replica_id=None, host_id=None):  # pylint: disable=invalid-name
      """Return the host device given replica_id or host_id."""
      assert _sentinel is None
      if replica_id is not None and host_id is not None:
        raise RuntimeError(
            'replica_id and host_id can have only one non-None value.')

      if master is None:
        return '/replica:0/task:0/device:CPU:0'
      else:
        if replica_id is not None:
          if self.model_parallelism_enabled:
            return self.device_assignment.host_device(
                replica=replica_id, job=master)
          else:
            host_id = replica_id // self.num_of_cores_per_host

        return '/job:%s/task:%d/device:CPU:0' % (master, host_id)

    return _placement_function
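  # For example, with master_job == 'tpu_worker' and 8 cores per host,
  # _placement_function(replica_id=13) resolves to
  # '/job:tpu_worker/task:1/device:CPU:0'; with a local master (master_job is
  # None) it always returns '/replica:0/task:0/device:CPU:0'.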

  @property
  def tpu_device_placement_function(self):
    """Returns a TPU device placement Fn."""
    master = self.master_job
    job_device = '' if master is None else ('/job:%s' % master)

    def _placement_function(i):
      if self.model_parallelism_enabled:
        return self.device_assignment.tpu_device(replica=i, job=master)
      else:
        num_of_cores_per_host = self.num_of_cores_per_host
        host_id = i // num_of_cores_per_host
        ordinal_id = i % num_of_cores_per_host
        return '%s/task:%d/device:TPU:%d' % (job_device, host_id, ordinal_id)

    return _placement_function
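  # For example, with master_job == 'tpu_worker' and 8 cores per host,
  # _placement_function(13) resolves to '/job:tpu_worker/task:1/device:TPU:5'.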

  def tpu_ordinal_function(self, host_id):
    """Returns the TPU ordinal fn."""

    def _tpu_ordinal_function(shard_index_in_host):
      """Return the TPU ordinal associated with a shard.

      Required because the enqueue ops are placed on CPU.

      Args:
        shard_index_in_host: the shard index

      Returns:
        The ordinal of the TPU device the shard's infeed should be placed on.
      """
      if self.model_parallelism_enabled:
        # We put both enqueue/dequeue ops at tpu.core(0) in each replica.
        replica = self.device_assignment.lookup_replicas(host_id,
                                                         0)[shard_index_in_host]
        return self.device_assignment.tpu_ordinal(replica=replica)
      else:
        return shard_index_in_host % self.num_of_cores_per_host

    return _tpu_ordinal_function

  def _validate_tpu_configuration(self):
    """Validates the configuration based on the TPU system metadata."""
    mode = self._assert_mode()
    if self._lazy_validation_dict.get(mode):
      return

    # All following information is obtained from TPU system metadata.
    num_cores = self.num_cores
    num_replicas = self.num_replicas
    num_hosts = self.num_hosts

    if not num_cores:
      tpu_system_metadata = self._get_tpu_system_metadata()
      raise RuntimeError(
          'Cannot find any TPU cores in the system. Please double check '
          'Tensorflow master address and TPU worker(s). Available devices '
          'are {}.'.format(tpu_system_metadata.devices))

    if self._config.tpu_config.num_shards:
      user_provided_num_replicas = self._config.tpu_config.num_shards
      if user_provided_num_replicas != num_replicas:
        message = (
            'TPUConfig.num_shards is not set correctly. According to TPU '
            'system metadata for Tensorflow master ({}): num_replicas should '
            'be ({}), got ({}). For non-model-parallelism, num_replicas should '
            'be the total num of TPU cores in the system. For '
            'model-parallelism, the total number of TPU cores should be '
            'num_cores_per_replica * num_replicas. Please set it '
            'accordingly or leave it as `None`'.format(
                self._get_master_address(), num_replicas,
                user_provided_num_replicas))

        raise ValueError(message)

    if self._config.tpu_config.num_cores_per_replica:
      num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
      num_cores_per_host = self._get_tpu_system_metadata().num_of_cores_per_host
      if num_cores_per_replica > num_cores_per_host:
        raise ValueError(
            'The num of cores required by the model parallelism, specified by '
            'TPUConfig.num_cores_per_replica, is larger than the '
            'num_cores_per_host. num_cores_per_replica: {}, '
            'num_cores_per_host: {}'.format(num_cores_per_replica,
                                            num_cores_per_host))

    if mode == model_fn_lib.ModeKeys.TRAIN:
      if (self._train_batch_size % num_replicas != 0 and
          not self.is_input_broadcast_with_iterators()):
        raise ValueError(
            'train batch size {} must be divisible by number of replicas {}'
            .format(self._train_batch_size, num_replicas))

    elif mode == model_fn_lib.ModeKeys.EVAL:
      if self._eval_batch_size is None:
        raise ValueError(
            'eval_batch_size in TPUEstimator constructor cannot be `None` '
            'if .evaluate is running on TPU.')
      if (self._eval_batch_size % num_replicas != 0 and
          not self.is_input_broadcast_with_iterators()):
        raise ValueError(
            'eval batch size {} must be divisible by number of replicas {}'
            .format(self._eval_batch_size, num_replicas))
      if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
        raise ValueError(
            'TPUEstimator.evaluate should be running on single TPU'
            ' instead of a Pod.')
    else:
      assert mode == model_fn_lib.ModeKeys.PREDICT
      if self._predict_batch_size is None:
        raise ValueError(
            'predict_batch_size in TPUEstimator constructor should not be '
            '`None` if .predict is running on TPU.')
      if (self._predict_batch_size % num_replicas != 0 and
          not self.is_input_broadcast_with_iterators()):
        raise ValueError(
            'predict batch size {} must be divisible by number of replicas {}'
            .format(self._predict_batch_size, num_replicas))
      if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
        raise ValueError(
            'TPUEstimator.predict should be running on single TPU worker. '
            'got {}.'.format(num_hosts))

    # Record the state "validated" into lazy dictionary.
    self._lazy_validation_dict[mode] = True

  def device_for_replica(self, replica_id):
    """Returns the tuple of (CPU device, device ordinal) for the replica.

    This should be used for full replication (non-model-parallelism).

    Args:
       replica_id: Int, the replica index.

    Returns:
       A tuple of the device spec for the CPU device and the int device
       ordinal.
    """
    master = self.master_job

    if self.model_parallelism_enabled:
      return (self.device_assignment.host_device(
          replica=replica_id, job=master),
              self.device_assignment.tpu_ordinal(replica=replica_id))

    job_device = '' if master is None else ('/job:%s' % master)

    num_of_replicas_per_host = self.num_of_replicas_per_host
    host_id = replica_id // num_of_replicas_per_host
    ordinal_id = replica_id % num_of_replicas_per_host

    host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id)
    return (host_device, ordinal_id)
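  # For illustration: without model parallelism, with 8 replicas per host and
  # master_job == 'tpu_worker', replica_id=13 maps to
  # ('/job:tpu_worker/task:1/device:CPU:0', 5).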


class _OneCoreTPUContext(_InternalTPUContext):
  """Special _InternalTPUContext for one core usage."""

  def __init__(self, config, train_batch_size, eval_batch_size,
               predict_batch_size, use_tpu):

    super(_OneCoreTPUContext, self).__init__(
        config, train_batch_size, eval_batch_size,
        predict_batch_size, use_tpu)

  def _get_tpu_system_metadata(self):
    """Gets the (maybe cached) TPU system metadata."""
    master = self._get_master_address()
    tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)
    if tpu_system_metadata is not None:
      return tpu_system_metadata

    tpu_system_metadata = (
        tpu_system_metadata_lib._TPUSystemMetadata(  # pylint: disable=protected-access
            num_cores=1,
            num_hosts=1,
            num_of_cores_per_host=1,
            topology=None,
            devices=[]))

    self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata
    return tpu_system_metadata


def _get_tpu_context(config, train_batch_size, eval_batch_size,
                     predict_batch_size, use_tpu, eval_on_tpu,
                     embedding_config_spec):
  """Returns an instance of `_InternalTPUContext`."""

  if (config.tpu_config.num_shards == 1 and
      config.tpu_config.num_cores_per_replica is None):
    if embedding_config_spec is not None:
      raise ValueError('Setting TPUConfig.num_shards==1 is unsupported '
                       'when embedding_config_spec is not None.')
    logging.warning(
        'Setting TPUConfig.num_shards==1 is an unsupported behavior. '
        'Please fix as soon as possible (leave num_shards as None).')
    return _OneCoreTPUContext(config, train_batch_size, eval_batch_size,
                              predict_batch_size, use_tpu)

  return _InternalTPUContext(config, train_batch_size, eval_batch_size,
                             predict_batch_size, use_tpu, eval_on_tpu,
                             embedding_config_spec)
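

# A minimal usage sketch of this factory (the RunConfig/TPUConfig values below
# are illustrative, not taken from this module):
#
#   config = tpu_config.RunConfig(
#       master='grpc://10.0.0.2:8470',
#       tpu_config=tpu_config.TPUConfig(num_shards=8))
#   ctx = _get_tpu_context(
#       config, train_batch_size=1024, eval_batch_size=1024,
#       predict_batch_size=8, use_tpu=True, eval_on_tpu=True,
#       embedding_config_spec=None)
#   with ctx.with_mode(model_fn_lib.ModeKeys.TRAIN) as train_ctx:
#     per_replica_batch = train_ctx.batch_size_for_model_fn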