# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================

"""A RunConfig subclass with TPU support."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import json
import os

import numpy as np

from tensorflow.contrib.tpu.python.tpu import util as util_lib
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.platform import tf_logging as logging

# pylint: disable=protected-access
_TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV
_SERVICE_KEY = run_config_lib._SERVICE_KEY
_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name'
_NUM_CORES_PER_HOST = 8

# pylint: enable=protected-access


# TODO(b/72511246) Provide a simplified api to configure model parallelism.
class TPUConfig(
    collections.namedtuple('TPUConfig', [
        'iterations_per_loop',
        'num_shards',
        'computation_shape',
        'per_host_input_for_training',
        'tpu_job_name',
        'initial_infeed_sleep_secs',
    ])):
  r"""TPU related configuration required by `TPUEstimator`.

  Args:
    iterations_per_loop: The number of train steps running in the TPU system
      before returning to the CPU host for each `Session.run`. This means the
      global step is increased `iterations_per_loop` times in one
      `Session.run`. It is recommended to set this to the number of global
      steps taken between checkpoints.
    num_shards: (Deprecated, ignored by TPUEstimator).
      The number of model replicas in the system. For the non-model-parallelism
      case, this number equals the total number of TPU cores. For
      model-parallelism, the total number of TPU cores equals
      product(computation_shape) * num_shards.
    computation_shape: Defaults to `None`, which disables model parallelism. A
      list of size 3 which describes the shape of a model replica's block of
      cores. This is required by model-parallelism, which enables partitioning
      the model across multiple cores. For example, [2, 2, 1] means the model
      is partitioned across 4 cores which span two cores in both the x and y
      coordinates. Please refer to ${tf.contrib.tpu.TopologyProto} for the
      geometry of a TPU mesh.
    per_host_input_for_training: If `True`, `input_fn` is invoked Per-Host
      rather than Per-Core. With Per-Host input pipeline deployment, `input_fn`
      is invoked once on each host. With Per-Core input pipeline deployment, it
      is invoked once for each core. To be precise, with a global batch size
      `train_batch_size` in the `TPUEstimator` constructor, the batch size for
      each shard is `train_batch_size` // #hosts. With Per-Core input pipeline
      deployment, the shard batch size is `train_batch_size` // #cores.
    tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred
      within TPUEstimator; however, when using ClusterSpec propagation in more
      esoteric cluster configurations, you may need to specify the job name as
      a string.
    initial_infeed_sleep_secs: The number of seconds the infeed thread should
      wait before enqueueing the first batch. This helps avoid timeouts for
      models that require a long compilation time.
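
  Example (an illustrative sketch; the values below are not the defaults and
  assume a single host with 8 TPU cores):

    tpu_config = TPUConfig(
        iterations_per_loop=100,
        num_shards=8,
        per_host_input_for_training=True)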

  Raises:
    ValueError: If `computation_shape` or `num_shards` are invalid.
  """

  def __new__(cls,
              iterations_per_loop=2,
              num_shards=None,
              computation_shape=None,
              per_host_input_for_training=True,
              tpu_job_name=None,
              initial_infeed_sleep_secs=None):

    # Check iterations_per_loop.
    util_lib.check_positive_integer(iterations_per_loop,
                                    'TPUConfig iterations_per_loop')

    # Check num_shards.
    if num_shards is not None:
      util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards')

    # Check computation_shape.
    if computation_shape is not None and len(computation_shape) != 3:
      raise ValueError(
          'computation_shape must be a list with length 3 or None; got {}'.
          format(str(computation_shape)))

    if computation_shape is not None:
      computation_shape_array = np.asarray(computation_shape, dtype=np.int32)
      # This prevents any computation from being replicated across multiple
      # hosts, so that each host feeds the same number of computations.
      if any(computation_shape_array < 1) or any(computation_shape_array > 2):
        raise ValueError('computation_shape elements can only be 1 or 2; got '
                         'computation_shape={}'.format(computation_shape))

    # Check initial_infeed_sleep_secs.
    if initial_infeed_sleep_secs:
      util_lib.check_positive_integer(initial_infeed_sleep_secs,
                                      'TPUConfig initial_infeed_sleep_secs')

    tpu_job_name = tpu_job_name or _get_tpu_job_name_from_tf_config()

    return super(TPUConfig, cls).__new__(
        cls,
        iterations_per_loop=iterations_per_loop,
        num_shards=num_shards,
        computation_shape=computation_shape,
        per_host_input_for_training=per_host_input_for_training,
        tpu_job_name=tpu_job_name,
        initial_infeed_sleep_secs=initial_infeed_sleep_secs)


class RunConfig(run_config_lib.RunConfig):
  """RunConfig with TPU support."""

  def __init__(self,
               tpu_config=None,
               evaluation_master=None,
               master=None,
               **kwargs):
    """Constructs a RunConfig.

    Args:
      tpu_config: the TPUConfig that specifies TPU-specific configuration.
      evaluation_master: a string. The address of the master to use for eval.
        Defaults to master if not set.
      master: a string. The address of the master to use for training.
      **kwargs: keyword config parameters.
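
    Example (an illustrative sketch; the master address is a placeholder and
    the `TPUConfig` values are not the defaults):

      config = RunConfig(
          master='grpc://<tpu-worker-address>:8470',
          tpu_config=TPUConfig(iterations_per_loop=100, num_shards=8))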
    """
    super(RunConfig, self).__init__(**kwargs)
    self._tpu_config = tpu_config or TPUConfig()

    # If the user sets master and/or evaluation_master explicitly, including
    # the empty string '', take it. Otherwise, take the values set by the
    # parent class.
    if master is not None:
      self._master = master

    if evaluation_master is not None:
      self._evaluation_master = evaluation_master
    elif (not self._evaluation_master and
          self.task_type != run_config_lib.TaskType.EVALUATOR):
      # If the task type is EVALUATOR, it means some cluster manager sets the
      # TF_CONFIG. In that case, we respect the configuration in TF_CONFIG.
      #
      # Otherwise, it means the user executes the code without an external
      # cluster manager. For that, we optimize the user experience by setting
      # evaluation_master to master, unless the user overwrites it.
      self._evaluation_master = self._master

  @property
  def evaluation_master(self):
    return self._evaluation_master

  @property
  def master(self):
    return self._master

  @property
  def tpu_config(self):
    return self._tpu_config

  def replace(self, **kwargs):
    if 'tpu_config' not in kwargs:
      return super(RunConfig, self).replace(**kwargs)

    tpu_config = kwargs.pop('tpu_config')
    new_instance = super(RunConfig, self).replace(**kwargs)
    new_instance._tpu_config = tpu_config  # pylint: disable=protected-access
    return new_instance
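
  # Usage sketch (illustrative): like the parent class, `replace` returns a
  # copied RunConfig rather than mutating in place; the override above
  # additionally lets callers swap in a new TPUConfig, e.g.
  #
  #   new_config = config.replace(tpu_config=TPUConfig(iterations_per_loop=50))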


def _get_tpu_job_name_from_tf_config():
  """Extracts the TPU job name from the TF_CONFIG env variable."""
  # TODO(xiejw): Extend this to support both the TF_CONFIG env variable and
  # cluster spec propagation.
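
  # A hedged sketch of the TF_CONFIG structure this function expects, written
  # in terms of the module-level constants so that no literal key names are
  # assumed:
  #
  #   os.environ[_TF_CONFIG_ENV] = json.dumps(
  #       {_SERVICE_KEY: {_TPU_WORKER_JOB_NAME: 'my_tpu_worker'}})
  #
  # With that environment set, this function returns 'my_tpu_worker'.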
  tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV, '{}'))
  tpu_job_name = tf_config.get(_SERVICE_KEY, {}).get(_TPU_WORKER_JOB_NAME)
  if tpu_job_name:
    logging.info('Load TPU job name from TF_CONFIG: %s', tpu_job_name)
  return tpu_job_name