# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================

"""A RunConfig subclass with TPU support."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import json
import os

import numpy as np

from tensorflow.contrib.tpu.python.tpu import util as util_lib
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.platform import tf_logging as logging

# pylint: disable=protected-access
_TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV
_SERVICE_KEY = run_config_lib._SERVICE_KEY
_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name'
_NUM_CORES_PER_HOST = 8

# pylint: enable=protected-access


# TODO(b/72511246) Provide a simplified api to configure model parallelism.
class TPUConfig(
    collections.namedtuple('TPUConfig', [
        'iterations_per_loop',
        'num_shards',
        'computation_shape',
        'per_host_input_for_training',
        'tpu_job_name',
        'initial_infeed_sleep_secs',
    ])):
  r"""TPU related configuration required by `TPUEstimator`.

  Args:
    iterations_per_loop: The number of train steps running in the TPU system
      before returning to the CPU host for each `Session.run`. This means the
      global step is increased `iterations_per_loop` times in one
      `Session.run`. It is recommended to set this to the number of global
      steps until the next checkpoint.
    num_shards: (Deprecated, ignored by TPUEstimator). The number of model
      replicas in the system. For the non-model-parallelism case, this number
      equals the total number of TPU cores. For model parallelism, the total
      number of TPU cores equals product(computation_shape) * num_shards.
    computation_shape: Defaults to `None`, which disables model parallelism.
      A list of size 3 which describes the shape of a model replica's block
      of cores. This is required for model parallelism, which enables
      partitioning the model across multiple cores. For example, [2, 2, 1]
      means the model is partitioned across 4 cores which span two cores in
      both x and y coordinates. Please refer to
      ${tf.contrib.tpu.TopologyProto} for the geometry of a TPU mesh.
    per_host_input_for_training: If `True`, `input_fn` is invoked per host
      rather than per core. With per-host input pipeline deployment,
      `input_fn` is invoked once on each host. With per-core input pipeline
      deployment, it is invoked once for each core. To be precise, with a
      global batch size `train_batch_size` in the `TPUEstimator` constructor,
      the batch size for each shard is `train_batch_size` // #hosts with
      per-host deployment, and `train_batch_size` // #cores with per-core
      deployment.
    tpu_job_name: The name of the TPU job. Typically, this name is
      auto-inferred within TPUEstimator; however, when using ClusterSpec
      propagation in more esoteric cluster configurations, you may need to
      specify the job name as a string.
    initial_infeed_sleep_secs: The number of seconds the infeed thread should
      wait before enqueueing the first batch. This helps avoid timeouts for
      models that require a long compilation time.

  Raises:
    ValueError: If `computation_shape` is invalid.
  """

  def __new__(cls,
              iterations_per_loop=2,
              num_shards=None,
              computation_shape=None,
              per_host_input_for_training=True,
              tpu_job_name=None,
              initial_infeed_sleep_secs=None):

    # Check iterations_per_loop.
    util_lib.check_positive_integer(iterations_per_loop,
                                    'TPUConfig iterations_per_loop')

    # Check num_shards.
    if num_shards is not None:
      util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards')

    # Check computation_shape.
    if computation_shape is not None and len(computation_shape) != 3:
      raise ValueError(
          'computation_shape must be a list with length 3 or None; got {}'.
          format(str(computation_shape)))

    if computation_shape is not None:
      computation_shape_array = np.asarray(computation_shape, dtype=np.int32)
      # This prevents any computation being replicated across multiple hosts,
      # so that each host feeds the same number of computations.
      if any(computation_shape_array < 1) or any(computation_shape_array > 2):
        raise ValueError('computation_shape elements can only be 1 or 2; got '
                         'computation_shape={}'.format(computation_shape))

    # Check initial_infeed_sleep_secs.
    if initial_infeed_sleep_secs:
      util_lib.check_positive_integer(initial_infeed_sleep_secs,
                                      'TPUConfig initial_infeed_sleep_secs')

    tpu_job_name = tpu_job_name or _get_tpu_job_name_from_tf_config()

    return super(TPUConfig, cls).__new__(
        cls,
        iterations_per_loop=iterations_per_loop,
        num_shards=num_shards,
        computation_shape=computation_shape,
        per_host_input_for_training=per_host_input_for_training,
        tpu_job_name=tpu_job_name,
        initial_infeed_sleep_secs=initial_infeed_sleep_secs)
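

# Example (a minimal sketch; the values below are placeholders): a TPUConfig
# for model parallelism where each replica spans a [2, 2, 1] block of cores
# and there are 2 such replicas, i.e. product([2, 2, 1]) * 2 = 8 cores in
# total:
#
#   tpu_cfg = TPUConfig(
#       iterations_per_loop=100,
#       num_shards=2,
#       computation_shape=[2, 2, 1])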


class RunConfig(run_config_lib.RunConfig):
  """RunConfig with TPU support."""

  def __init__(self,
               tpu_config=None,
               evaluation_master=None,
               master=None,
               **kwargs):
    """Constructs a RunConfig.

    Args:
      tpu_config: the TPUConfig that specifies TPU-specific configuration.
      evaluation_master: a string. The address of the master to use for eval.
        Defaults to master if not set.
      master: a string. The address of the master to use for training.
      **kwargs: keyword config parameters.
    """
    super(RunConfig, self).__init__(**kwargs)
    self._tpu_config = tpu_config or TPUConfig()

    # If the user sets master and/or evaluation_master explicitly, including
    # the empty string '', take it. Otherwise, take the values set by the
    # parent class.
    if master is not None:
      self._master = master

    if evaluation_master is not None:
      self._evaluation_master = evaluation_master
    elif (not self._evaluation_master and
          self.task_type != run_config_lib.TaskType.EVALUATOR):
      # If the task type is EVALUATOR, it means some cluster manager sets the
      # TF_CONFIG. In that case, we respect the configuration in TF_CONFIG.
      #
      # Otherwise, it means the user executes the code without an external
      # cluster manager. For that, we optimize the user experience by setting
      # evaluation_master to master, unless the user overwrites it.
      self._evaluation_master = self._master

  @property
  def evaluation_master(self):
    return self._evaluation_master

  @property
  def master(self):
    return self._master

  @property
  def tpu_config(self):
    return self._tpu_config

  def replace(self, **kwargs):
    if 'tpu_config' not in kwargs:
      return super(RunConfig, self).replace(**kwargs)

    tpu_config = kwargs.pop('tpu_config')
    new_instance = super(RunConfig, self).replace(**kwargs)
    new_instance._tpu_config = tpu_config  # pylint: disable=protected-access
    return new_instance
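

# Example usage (a minimal sketch; the master address and model directory
# below are placeholders rather than values defined in this module):
#
#   config = RunConfig(
#       master='grpc://10.240.1.2:8470',
#       model_dir='gs://my-bucket/model',
#       tpu_config=TPUConfig(iterations_per_loop=100))
#
#   # `replace` special-cases tpu_config, since the parent class does not
#   # know about this field.
#   eval_config = config.replace(
#       tpu_config=TPUConfig(iterations_per_loop=1))
#
# The resulting config is typically passed to `TPUEstimator` through its
# `config` argument.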


def _get_tpu_job_name_from_tf_config():
  """Extracts the TPU job name from TF_CONFIG env variable."""
  # TODO(xiejw): Extends this to support both TF_CONFIG env variable and
  # cluster spec propagation.
  tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV, '{}'))
  tpu_job_name = tf_config.get(_SERVICE_KEY, {}).get(_TPU_WORKER_JOB_NAME)
  if tpu_job_name:
    logging.info('Load TPU job name from TF_CONFIG: %s', tpu_job_name)
  return tpu_job_name
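

# Example of a TF_CONFIG payload from which the helper above would pick up
# the TPU job name (a sketch; a real cluster manager would also include the
# cluster spec, which is omitted here):
#
#   os.environ[_TF_CONFIG_ENV] = json.dumps({
#       _SERVICE_KEY: {
#           _TPU_WORKER_JOB_NAME: 'my_tpu_worker',
#       },
#   })
#
# With TF_CONFIG set this way, `_get_tpu_job_name_from_tf_config()` returns
# 'my_tpu_worker', and `TPUConfig.__new__` uses it whenever `tpu_job_name` is
# not given explicitly.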