# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Environment configuration object for Estimators."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import json
import os

import six

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat_internal
from tensorflow.python.util.tf_export import tf_export


_USE_DEFAULT = object()

# A list of the property names in RunConfig that the user is allowed to change.
_DEFAULT_REPLACEABLE_LIST = [
    'model_dir',
    'tf_random_seed',
    'save_summary_steps',
    'save_checkpoints_steps',
    'save_checkpoints_secs',
    'session_config',
    'keep_checkpoint_max',
    'keep_checkpoint_every_n_hours',
    'log_step_count_steps'
]

_SAVE_CKPT_ERR = (
    '`save_checkpoints_steps` and `save_checkpoints_secs` cannot be both set.'
)

_TF_CONFIG_ENV = 'TF_CONFIG'
_TASK_ENV_KEY = 'task'
_TASK_TYPE_KEY = 'type'
_TASK_ID_KEY = 'index'
_CLUSTER_KEY = 'cluster'
_SERVICE_KEY = 'service'
_SESSION_MASTER_KEY = 'session_master'
_EVAL_SESSION_MASTER_KEY = 'eval_session_master'
_MODEL_DIR_KEY = 'model_dir'
_LOCAL_MASTER = ''
_GRPC_SCHEME = 'grpc://'


def _get_session_master(cluster_spec, task_type, task_id, tf_config):
  """Returns the appropriate address for the TensorFlow master.

  The order of precedence to determine the TF session master is as follows:
  1. If `session_master` is set in the TF_CONFIG environment variable, use it.
  2. If the cluster has only one node, return the empty string ''.
  3. Return the grpc address according to the task type and id in the cluster.
     This is between-graph replication.

  Note: `task_type` and `task_id` must already be validated, typically via
  `_validate_task_type_and_task_id`.

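  For example, a minimal sketch (the host addresses are illustrative only):

  ```
    cluster_spec = server_lib.ClusterSpec(
        {'chief': ['host0:2222'], 'worker': ['host1:2222', 'host2:2222']})
    _get_session_master(cluster_spec, 'worker', 1, {})
    # -> 'grpc://host2:2222', i.e. the grpc address of the (type, id) task.
    _get_session_master(cluster_spec, 'chief', 0, {'session_master': 'local'})
    # -> 'local', because the TF_CONFIG value takes precedence.
  ```
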
  Args:
    cluster_spec: A `ClusterSpec` instance.
    task_type: String. Task type for current node.
    task_id: Int. Task id for current node.
    tf_config: Dict. Python dict for the TF_CONFIG environment variable.

  Raises:
    RuntimeError: If `cluster_spec` is not set.

  """
  if _SESSION_MASTER_KEY in tf_config:
    return tf_config[_SESSION_MASTER_KEY]

  if not cluster_spec:
    raise RuntimeError('Internal error: `_get_session_master` '
                       'does not expect empty cluster_spec.')

  jobs = cluster_spec.jobs

  # If there is only one node in the cluster, do things locally by setting
  # master to ''.  If a service or user sets TF_CONFIG with a single node, it's
  # more performant to use a direct master rather than an RPC service.
  if len(jobs) == 1 and len(cluster_spec.job_tasks(jobs[0])) == 1:
    return _LOCAL_MASTER

  # Lookup the master in cluster_spec using task_type and task_id,
  # if possible.
  addresses = cluster_spec.job_tasks(task_type)
  return _GRPC_SCHEME + addresses[task_id]


def _get_eval_session_master(task_type, tf_config):
  """Returns the appropriate address for TensorFlow evaluation master."""
  if task_type == TaskType.EVALUATOR:
    return tf_config.get(_EVAL_SESSION_MASTER_KEY, _LOCAL_MASTER)

  if _EVAL_SESSION_MASTER_KEY in tf_config:
    raise ValueError('Key ({}) should not be set for task type other than {}. '
                     'Task type: {}'.format(_EVAL_SESSION_MASTER_KEY,
                                            TaskType.EVALUATOR, task_type))
  return _LOCAL_MASTER


def _count_ps(cluster_spec):
  """Counts the number of parameter servers in cluster_spec."""
  if not cluster_spec:
    raise RuntimeError(
        'Internal error: `_count_ps` does not expect empty cluster_spec.')

  return len(cluster_spec.as_dict().get(TaskType.PS, []))


def _count_worker(cluster_spec, chief_task_type):
  """Counts the number of workers (including chief) in cluster_spec."""
  if not cluster_spec:
    raise RuntimeError(
        'Internal error: `_count_worker` does not expect empty cluster_spec.')

  return (len(cluster_spec.as_dict().get(TaskType.WORKER, [])) +
          len(cluster_spec.as_dict().get(chief_task_type, [])))


def _validate_service(service):
  """Validates the service key."""
  if service is not None and not isinstance(service, dict):
    raise TypeError(
        'If "service" is set in TF_CONFIG, it must be a dict. Given %s' %
        type(service))
  return service


def _validate_task_type_and_task_id(cluster_spec, task_env, chief_task_type):
  """Validates the task type and index in `task_env` according to cluster."""
  if chief_task_type not in cluster_spec.jobs:
    raise ValueError(
        'If "cluster" is set in TF_CONFIG, it must have one "%s" node.' %
        chief_task_type)
  if len(cluster_spec.job_tasks(chief_task_type)) > 1:
    raise ValueError(
        'The "cluster" in TF_CONFIG must have only one "%s" node.' %
        chief_task_type)

  task_type = task_env.get(_TASK_TYPE_KEY, None)
  task_id = task_env.get(_TASK_ID_KEY, None)

  if not task_type:
    raise ValueError(
        'If "cluster" is set in TF_CONFIG, task type must be set.')
  if task_id is None:
    raise ValueError(
        'If "cluster" is set in TF_CONFIG, task index must be set.')

  task_id = int(task_id)

  # Check the task id bounds. Upper bound is not necessary as
  # - for evaluator, there is no upper bound.
  # - for non-evaluator, task id is upper bounded by the number of jobs in
  # cluster spec, which will be checked later (when retrieving the `master`)
  if task_id < 0:
    raise ValueError('Task index must be a non-negative number.')

  # Evaluator is not part of the training cluster.
  if task_type == TaskType.EVALUATOR:
    return task_type, task_id

  if task_type not in cluster_spec.jobs:
    raise ValueError(
        '%s is not a valid task_type in the cluster_spec:\n'
        '%s\n\n'
        'Note that these values may be coming from the TF_CONFIG environment '
        'variable.' % (task_type, cluster_spec))
  addresses = cluster_spec.job_tasks(task_type)
  if not 0 <= task_id < len(addresses):
    raise ValueError(
        '%d is not a valid task_id for task_type %s in the cluster_spec:\n'
        '%s\n\n'
        'Note that these values may be coming from the TF_CONFIG environment '
        'variable.' % (task_id, task_type, cluster_spec))

  return task_type, task_id


def _get_global_id_in_cluster(
    cluster_spec, task_type, task_id, chief_task_type):
  """Returns the global id in cluster."""
  # Note: This is an implementation detail, which users should not rely on.
  # The first id is 0 and is always assigned to the `chief` node. All other
  # nodes, except `ps`, are ordered by task type (alphabetically) and then by
  # task id (ascending). `ps` nodes are ordered last.
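  # For example (illustrative), with a cluster of
  # {'chief': 1 task, 'worker': 2 tasks, 'ps': 2 tasks}, the assignment is:
  # chief -> 0, worker -> 1 and 2, ps -> 3 and 4.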

  # Sort task names in cluster
  task_type_ordered_list = [chief_task_type]
  task_type_ordered_list.extend([
      t for t in sorted(cluster_spec.jobs)
      if t != chief_task_type and t != TaskType.PS
  ])
  if TaskType.PS in cluster_spec.jobs:
    task_type_ordered_list.append(TaskType.PS)

  next_global_id = 0
  for t in task_type_ordered_list:
    if t == task_type:
      return next_global_id + task_id
    next_global_id += len(cluster_spec.job_tasks(t))

  # This should never happen.
  raise RuntimeError('Internal Error: `task_type` ({}) is not in '
                     'cluster_spec ({}).'.format(task_type, cluster_spec))


def _validate_save_ckpt_with_replaced_keys(new_copy, replaced_keys):
  """Validates the save ckpt properties."""
  # Ensure that at most one of save_steps and save_secs is not None.
  # Also, if the user sets one save-checkpoint property, say steps, the other
  # one (secs) should be set to None to improve usability.

  save_steps = new_copy.save_checkpoints_steps
  save_secs = new_copy.save_checkpoints_secs

  if ('save_checkpoints_steps' in replaced_keys and
      'save_checkpoints_secs' in replaced_keys):
    # If the user sets both properties explicitly, error out if both are
    # non-None; only one of them may be active at a time.
    if save_steps is not None and save_secs is not None:
      raise ValueError(_SAVE_CKPT_ERR)
  elif 'save_checkpoints_steps' in replaced_keys and save_steps is not None:
    new_copy._save_checkpoints_secs = None  # pylint: disable=protected-access
  elif 'save_checkpoints_secs' in replaced_keys and save_secs is not None:
    new_copy._save_checkpoints_steps = None  # pylint: disable=protected-access


def _validate_properties(run_config):
  """Validates the properties."""
  def _validate(property_name, cond, message):
    property_value = getattr(run_config, property_name)
    if property_value is not None and not cond(property_value):
      raise ValueError(message)

  _validate('model_dir', lambda dir: dir,
            message='model_dir should be non-empty')

  _validate('save_summary_steps', lambda steps: steps >= 0,
            message='save_summary_steps should be >= 0')

  _validate('save_checkpoints_steps', lambda steps: steps >= 0,
            message='save_checkpoints_steps should be >= 0')
  _validate('save_checkpoints_secs', lambda secs: secs >= 0,
            message='save_checkpoints_secs should be >= 0')

  _validate('session_config',
            lambda sc: isinstance(sc, config_pb2.ConfigProto),
            message='session_config must be instance of ConfigProto')

  _validate('keep_checkpoint_max', lambda keep_max: keep_max >= 0,
            message='keep_checkpoint_max should be >= 0')
  _validate('keep_checkpoint_every_n_hours', lambda keep_hours: keep_hours > 0,
            message='keep_checkpoint_every_n_hours should be > 0')
  _validate('log_step_count_steps', lambda num_steps: num_steps > 0,
            message='log_step_count_steps should be > 0')

  _validate('tf_random_seed', lambda seed: isinstance(seed, six.integer_types),
            message='tf_random_seed must be integer.')


class TaskType(object):
  MASTER = 'master'
  PS = 'ps'
  WORKER = 'worker'
  CHIEF = 'chief'
  EVALUATOR = 'evaluator'


@tf_export('estimator.RunConfig')
class RunConfig(object):
  """This class specifies the configurations for an `Estimator` run."""

  def __init__(self,
               model_dir=None,
               tf_random_seed=None,
               save_summary_steps=100,
               save_checkpoints_steps=_USE_DEFAULT,
               save_checkpoints_secs=_USE_DEFAULT,
               session_config=None,
               keep_checkpoint_max=5,
               keep_checkpoint_every_n_hours=10000,
               log_step_count_steps=100):
    """Constructs a RunConfig.

    All distributed training related properties `cluster_spec`, `is_chief`,
    `master`, `num_worker_replicas`, `num_ps_replicas`, `task_id`, and
    `task_type` are set based on the `TF_CONFIG` environment variable, if the
    pertinent information is present. The `TF_CONFIG` environment variable is a
    JSON object with attributes: `cluster` and `task`.

    `cluster` is a JSON serialized version of `ClusterSpec`'s Python dict from
    `server_lib.py`, mapping task types (usually one of the `TaskType` enums) to
    a list of task addresses.

    `task` has two attributes: `type` and `index`, where `type` can be any of
    the task types in `cluster`. When `TF_CONFIG` contains said information,
    the following properties are set on this class:

    * `cluster_spec` is parsed from `TF_CONFIG['cluster']`. Defaults to {}. If
      present, it must have one and only one node in the `chief` attribute of
      `cluster_spec`.
    * `task_type` is set to `TF_CONFIG['task']['type']`. Must be set if
      `cluster_spec` is present; must be `worker` (the default value) if
      `cluster_spec` is not set.
    * `task_id` is set to `TF_CONFIG['task']['index']`. Must be set if
      `cluster_spec` is present; must be 0 (the default value) if
      `cluster_spec` is not set.
    * `master` is determined by looking up `task_type` and `task_id` in the
      `cluster_spec`. Defaults to ''.
    * `num_ps_replicas` is set by counting the number of nodes listed
      in the `ps` attribute of `cluster_spec`. Defaults to 0.
    * `num_worker_replicas` is set by counting the number of nodes listed
      in the `worker` and `chief` attributes of `cluster_spec`. Defaults to 1.
    * `is_chief` is determined based on `task_type` and `cluster`.

    There is a special node with `task_type` as `evaluator`, which is not part
    of the (training) `cluster_spec`. It handles the distributed evaluation job.

    Example of non-chief node:
    ```
      cluster = {'chief': ['host0:2222'],
                 'ps': ['host1:2222', 'host2:2222'],
                 'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
      os.environ['TF_CONFIG'] = json.dumps(
          {'cluster': cluster,
           'task': {'type': 'worker', 'index': 1}})
      config = RunConfig()
      assert config.master == 'grpc://host4:2222'
      assert config.task_id == 1
      assert config.num_ps_replicas == 2
      assert config.num_worker_replicas == 4
      assert config.cluster_spec == server_lib.ClusterSpec(cluster)
      assert config.task_type == 'worker'
      assert not config.is_chief
    ```

    Example of chief node:
    ```
      cluster = {'chief': ['host0:2222'],
                 'ps': ['host1:2222', 'host2:2222'],
                 'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
      os.environ['TF_CONFIG'] = json.dumps(
          {'cluster': cluster,
           'task': {'type': 'chief', 'index': 0}})
      config = RunConfig()
      assert config.master == 'grpc://host0:2222'
      assert config.task_id == 0
      assert config.num_ps_replicas == 2
      assert config.num_worker_replicas == 4
      assert config.cluster_spec == server_lib.ClusterSpec(cluster)
      assert config.task_type == 'chief'
      assert config.is_chief
    ```

    Example of evaluator node (evaluator is not part of training cluster):
    ```
      cluster = {'chief': ['host0:2222'],
                 'ps': ['host1:2222', 'host2:2222'],
                 'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
      os.environ['TF_CONFIG'] = json.dumps(
          {'cluster': cluster,
           'task': {'type': 'evaluator', 'index': 0}})
      config = RunConfig()
      assert config.master == ''
      assert config.evaluation_master == ''
      assert config.task_id == 0
      assert config.num_ps_replicas == 0
      assert config.num_worker_replicas == 0
      assert config.cluster_spec == {}
      assert config.task_type == 'evaluator'
      assert not config.is_chief
    ```

    N.B.: If `save_checkpoints_steps` or `save_checkpoints_secs` is set,
    `keep_checkpoint_max` might need to be adjusted accordingly, especially in
    distributed training. For example, setting `save_checkpoints_secs` to 60
    without adjusting `keep_checkpoint_max` (which defaults to 5) leads to a
    situation where a checkpoint would be garbage collected 5 minutes after it
    was written. In distributed training, the evaluation job starts
    asynchronously and might fail to load or find the checkpoint due to a race
    condition.

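    A minimal sketch of such an adjustment (the values are illustrative only):

    ```
      config = RunConfig(
          save_checkpoints_secs=60,  # checkpoint every minute ...
          keep_checkpoint_max=30)    # ... but retain ~30 minutes of history
    ```
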
    Args:
      model_dir: Directory where model parameters, graph, etc. are saved. If a
        `PathLike` object, the path will be resolved. If `None`, a default
        value set by the Estimator will be used.
      tf_random_seed: Random seed for TensorFlow initializers.
        Setting this value allows consistency between reruns.
      save_summary_steps: Save summaries every this many steps.
      save_checkpoints_steps: Save checkpoints every this many steps. Cannot be
          specified with `save_checkpoints_secs`.
      save_checkpoints_secs: Save checkpoints every this many seconds. Cannot
          be specified with `save_checkpoints_steps`. Defaults to 600 seconds if
          both `save_checkpoints_steps` and `save_checkpoints_secs` are not set
          in the constructor.  If both `save_checkpoints_steps` and
          `save_checkpoints_secs` are None, then checkpoints are disabled.
      session_config: A `ConfigProto` used to set session parameters, or None.
      keep_checkpoint_max: The maximum number of recent checkpoint files to
        keep. As new files are created, older files are deleted. If None or 0,
        all checkpoint files are kept. Defaults to 5 (that is, the 5 most recent
        checkpoint files are kept).
      keep_checkpoint_every_n_hours: Number of hours between each checkpoint
        to be saved. The default value of 10,000 hours effectively disables
        the feature.
      log_step_count_steps: The frequency, in number of global steps, with
        which the global step/sec will be logged during training.

    Raises:
      ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs`
        are set.
    """
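    # Resolve the checkpoint cadence: if the user specified neither option,
    # default to time-based checkpoints every 600 seconds; if exactly one was
    # given, clear the other; if both were given non-None values, raise below.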
    if (save_checkpoints_steps == _USE_DEFAULT and
        save_checkpoints_secs == _USE_DEFAULT):
      save_checkpoints_steps = None
      save_checkpoints_secs = 600
    elif save_checkpoints_secs == _USE_DEFAULT:
      save_checkpoints_secs = None
    elif save_checkpoints_steps == _USE_DEFAULT:
      save_checkpoints_steps = None
    elif (save_checkpoints_steps is not None and
          save_checkpoints_secs is not None):
      raise ValueError(_SAVE_CKPT_ERR)

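    # TF_CONFIG, if present, is expected to be a JSON-encoded dict describing
    # the cluster and this node's task (see the class docstring above).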
    tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV, '{}'))
    if tf_config:
      logging.info('TF_CONFIG environment variable: %s', tf_config)

    model_dir = _get_model_dir(tf_config,
                               compat_internal.path_to_str(model_dir))

    RunConfig._replace(
        self,
        allowed_properties_list=_DEFAULT_REPLACEABLE_LIST,
        model_dir=model_dir,
        tf_random_seed=tf_random_seed,
        save_summary_steps=save_summary_steps,
        save_checkpoints_steps=save_checkpoints_steps,
        save_checkpoints_secs=save_checkpoints_secs,
        session_config=session_config,
        keep_checkpoint_max=keep_checkpoint_max,
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
        log_step_count_steps=log_step_count_steps)

    self._init_distributed_setting_from_environment_var(tf_config)

  def _init_distributed_setting_from_environment_var(self, tf_config):
    """Initialize distributed properties based on `tf_config`."""

    self._service = _validate_service(tf_config.get(_SERVICE_KEY))
    self._cluster_spec = server_lib.ClusterSpec(tf_config.get(_CLUSTER_KEY, {}))
    task_env = tf_config.get(_TASK_ENV_KEY, {})

    if self._cluster_spec and TaskType.MASTER in self._cluster_spec.jobs:
      return self._init_distributed_setting_from_environment_var_with_master(
          tf_config)

    if self._cluster_spec:
      # Distributed mode.
      self._task_type, self._task_id = _validate_task_type_and_task_id(
          self._cluster_spec, task_env, TaskType.CHIEF)

      self._evaluation_master = _get_eval_session_master(
          self._task_type, tf_config)

      if self._task_type != TaskType.EVALUATOR:
        self._master = _get_session_master(self._cluster_spec, self._task_type,
                                           self._task_id, tf_config)
        self._num_ps_replicas = _count_ps(self._cluster_spec)
        self._num_worker_replicas = _count_worker(
            self._cluster_spec, chief_task_type=TaskType.CHIEF)
        self._global_id_in_cluster = _get_global_id_in_cluster(
            self._cluster_spec,
            self._task_type,
            self._task_id,
            chief_task_type=TaskType.CHIEF)
      else:
        # Evaluator is not part of the training cluster.
        self._cluster_spec = server_lib.ClusterSpec({})
        self._master = _LOCAL_MASTER
        self._num_ps_replicas = 0
        self._num_worker_replicas = 0
        self._global_id_in_cluster = None  # undefined

      self._is_chief = self._task_type == TaskType.CHIEF
    else:
      # Local mode.
      self._task_type = task_env.get(_TASK_TYPE_KEY, TaskType.WORKER)
      self._task_id = int(task_env.get(_TASK_ID_KEY, 0))
      self._global_id_in_cluster = 0

      if self._task_type != TaskType.WORKER:
        raise ValueError(
            'If "cluster" is not set in TF_CONFIG, task type must be WORKER.')
      if self._task_id != 0:
        raise ValueError(
            'If "cluster" is not set in TF_CONFIG, task index must be 0.')

      self._master = tf_config.get(_SESSION_MASTER_KEY, _LOCAL_MASTER)
      self._evaluation_master = tf_config.get(_EVAL_SESSION_MASTER_KEY,
                                              _LOCAL_MASTER)
      self._is_chief = True
      self._num_ps_replicas = 0
      self._num_worker_replicas = 1

  def _init_distributed_setting_from_environment_var_with_master(self,
                                                                 tf_config):
    """Initialize distributed properties for legacy cluster with `master`."""
    # There is no technical reason why a user cannot have both chief and master
    # in the same cluster, but it is very confusing (which one is really the
    # chief?), so this case is blocked.
    if TaskType.CHIEF in self._cluster_spec.jobs:
      raise ValueError('If `master` node exists in `cluster`, job '
                       '`chief` is not supported.')

    task_env = tf_config.get(_TASK_ENV_KEY, {})

    self._task_type, self._task_id = _validate_task_type_and_task_id(
        self._cluster_spec, task_env, TaskType.MASTER)

    if self._task_type == TaskType.EVALUATOR:
      raise ValueError('If `master` node exists in `cluster`, task_type '
                       '`evaluator` is not supported.')

    self._global_id_in_cluster = _get_global_id_in_cluster(
        self._cluster_spec,
        self._task_type,
        self._task_id,
        chief_task_type=TaskType.MASTER)

    self._master = _get_session_master(self._cluster_spec, self._task_type,
                                       self._task_id, tf_config)
    self._evaluation_master = _get_eval_session_master(self._task_type,
                                                       tf_config)
    self._num_ps_replicas = _count_ps(self._cluster_spec)
    self._num_worker_replicas = _count_worker(
        self._cluster_spec, chief_task_type=TaskType.MASTER)

    self._is_chief = self._task_type == TaskType.MASTER

  @property
  def cluster_spec(self):
    return self._cluster_spec

  @property
  def evaluation_master(self):
    return self._evaluation_master

  @property
  def is_chief(self):
    return self._is_chief

  @property
  def master(self):
    return self._master

  @property
  def num_ps_replicas(self):
    return self._num_ps_replicas

  @property
  def num_worker_replicas(self):
    return self._num_worker_replicas

  @property
  def task_id(self):
    return self._task_id

  @property
  def global_id_in_cluster(self):
    """The global id in the training cluster.

    All global ids in the training cluster are assigned from an increasing
    sequence of consecutive integers. The first id is 0.

    Note: Task id (the property field `task_id`) tracks the index of the node
    among all nodes with the SAME task type. For example, given the cluster
    definition as follows:

    ```
      cluster = {'chief': ['host0:2222'],
                 'ps': ['host1:2222', 'host2:2222'],
                 'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
    ```

    Nodes with task type `worker` can have id 0, 1, or 2.  Nodes with task type
    `ps` can have id 0 or 1. So, `task_id` is not unique, but the pair
    (`task_type`, `task_id`) uniquely determines a node in the cluster.

    The global id, i.e., this field, tracks the index of the node among ALL
    nodes in the cluster. It is uniquely assigned.  For example, for the cluster
    spec given above, the global ids are assigned as:
    ```
      task_type  | task_id  |  global_id
      --------------------------------
      chief      | 0        |  0
      worker     | 0        |  1
      worker     | 1        |  2
      worker     | 2        |  3
      ps         | 0        |  4
      ps         | 1        |  5
    ```

    Returns:
      An integer id.
    """
    return self._global_id_in_cluster

  @property
  def task_type(self):
    return self._task_type

  @property
  def tf_random_seed(self):
    return self._tf_random_seed

  @property
  def save_summary_steps(self):
    return self._save_summary_steps

  @property
  def save_checkpoints_secs(self):
    return self._save_checkpoints_secs

  @property
  def session_config(self):
    return self._session_config

  @property
  def save_checkpoints_steps(self):
    return self._save_checkpoints_steps

  @property
  def keep_checkpoint_max(self):
    return self._keep_checkpoint_max

  @property
  def keep_checkpoint_every_n_hours(self):
    return self._keep_checkpoint_every_n_hours

  @property
  def log_step_count_steps(self):
    return self._log_step_count_steps

  @property
  def model_dir(self):
    return self._model_dir

  @property
  def service(self):
    """Returns the platform defined (in TF_CONFIG) service dict."""
    return self._service

  def replace(self, **kwargs):
    """Returns a new instance of `RunConfig` replacing specified properties.

    Only the properties in the following list are allowed to be replaced:

      - `model_dir`,
      - `tf_random_seed`,
      - `save_summary_steps`,
      - `save_checkpoints_steps`,
      - `save_checkpoints_secs`,
      - `session_config`,
      - `keep_checkpoint_max`,
      - `keep_checkpoint_every_n_hours`,
      - `log_step_count_steps`.

    In addition, either `save_checkpoints_steps` or `save_checkpoints_secs`
    can be set (but not both).

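    For example, a minimal sketch (the new cadence value is illustrative):

    ```
      config = RunConfig(save_checkpoints_secs=600)
      new_config = config.replace(save_checkpoints_steps=1000)
      assert new_config.save_checkpoints_steps == 1000
      assert new_config.save_checkpoints_secs is None
    ```
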
    Args:
      **kwargs: keyword named properties with new values.

    Raises:
      ValueError: If any property name in `kwargs` does not exist or is not
        allowed to be replaced, or both `save_checkpoints_steps` and
        `save_checkpoints_secs` are set.

    Returns:
      a new instance of `RunConfig`.
    """
    return RunConfig._replace(
        copy.deepcopy(self),
        allowed_properties_list=_DEFAULT_REPLACEABLE_LIST,
        **kwargs)

  @staticmethod
  def _replace(config, allowed_properties_list=None, **kwargs):
    """See `replace`.

    N.B.: This implementation assumes that for a key named "foo", the
    underlying property the RunConfig holds is "_foo" (with one leading
    underscore).

    Args:
      config: The RunConfig to replace the values of.
      allowed_properties_list: The property name list allowed to be replaced.
      **kwargs: keyword named properties with new values.

    Raises:
      ValueError: If any property name in `kwargs` does not exist or is not
        allowed to be replaced, or both `save_checkpoints_steps` and
        `save_checkpoints_secs` are set.

    Returns:
      a new instance of `RunConfig`.
    """

    allowed_properties_list = allowed_properties_list or []

    for key, new_value in six.iteritems(kwargs):
      if key in allowed_properties_list:
        setattr(config, '_' + key, new_value)
        continue

      raise ValueError(
          'Replacing {} is not supported. Allowed properties are {}.'.format(
              key, allowed_properties_list))

    _validate_save_ckpt_with_replaced_keys(config, kwargs.keys())
    _validate_properties(config)
    return config


def _get_model_dir(tf_config, model_dir):
  """Returns `model_dir` from the user-provided `tf_config` or `model_dir`."""
  # pylint: disable=g-explicit-bool-comparison

  # An empty string is treated as False in a Python condition check, which can
  # trigger confusing error messages. For example, 'a or b' returns None if a
  # is '' and b is None. `None` is allowed for model_dir, but '' is not. Here,
  # explicitly check for the empty string to provide a clear error message.
  if model_dir == '':
    raise ValueError('model_dir should be non-empty.')

  model_dir_in_tf_config = tf_config.get('model_dir')
  if model_dir_in_tf_config == '':
    raise ValueError('model_dir in TF_CONFIG should be non-empty.')

  if model_dir_in_tf_config:
    if model_dir and model_dir_in_tf_config != model_dir:
      raise ValueError(
          '`model_dir` provided in the RunConfig constructor, if set, '
          'must have the same value as the model_dir in TF_CONFIG. '
          'model_dir: {}\nTF_CONFIG["model_dir"]: {}.\n'.format(
              model_dir, model_dir_in_tf_config))

    logging.info('Using model_dir in TF_CONFIG: %s', model_dir_in_tf_config)

  return model_dir or model_dir_in_tf_config
    771