1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Environment configuration object for Estimators.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import copy 22 import json 23 import os 24 25 import six 26 27 from tensorflow.core.protobuf import config_pb2 28 from tensorflow.python.platform import tf_logging as logging 29 from tensorflow.python.training import server_lib 30 from tensorflow.python.util import compat_internal 31 from tensorflow.python.util.tf_export import tf_export 32 33 34 _USE_DEFAULT = object() 35 36 # A list of the property names in RunConfig that the user is allowed to change. 37 _DEFAULT_REPLACEABLE_LIST = [ 38 'model_dir', 39 'tf_random_seed', 40 'save_summary_steps', 41 'save_checkpoints_steps', 42 'save_checkpoints_secs', 43 'session_config', 44 'keep_checkpoint_max', 45 'keep_checkpoint_every_n_hours', 46 'log_step_count_steps' 47 ] 48 49 _SAVE_CKPT_ERR = ( 50 '`save_checkpoints_steps` and `save_checkpoints_secs` cannot be both set.' 51 ) 52 53 _TF_CONFIG_ENV = 'TF_CONFIG' 54 _TASK_ENV_KEY = 'task' 55 _TASK_TYPE_KEY = 'type' 56 _TASK_ID_KEY = 'index' 57 _CLUSTER_KEY = 'cluster' 58 _SERVICE_KEY = 'service' 59 _SESSION_MASTER_KEY = 'session_master' 60 _EVAL_SESSION_MASTER_KEY = 'eval_session_master' 61 _MODEL_DIR_KEY = 'model_dir' 62 _LOCAL_MASTER = '' 63 _GRPC_SCHEME = 'grpc://' 64 65 66 def _get_session_master(cluster_spec, task_type, task_id, tf_config): 67 """Returns the appropriate address for TensorFlow master. 68 69 The order of precedence to deteremine the TF session master is as follows: 70 1. If `tf_session_master` is set in TF_CONFIG environment variable, takes it. 71 2. If the cluster has only one node, returns empty string ''. 72 3. Returns the grpc address according to the task type and id in the cluster. 73 This is between-graph replication. 74 75 Note: task_type and task_id must be validated. Typically, validated using 76 `_validate_task_type_and_task_id`. 77 78 Args: 79 cluster_spec: A `ClusterSpec` instance. 80 task_type: String. Task type for current node. 81 task_id: Int. Task id for current node. 82 tf_config: Dict. Python dict for the TF_CONFIG environment variable. 83 84 Raises: 85 RuntimeError: If `cluster_spec` is not set. 86 87 """ 88 if _SESSION_MASTER_KEY in tf_config: 89 return tf_config[_SESSION_MASTER_KEY] 90 91 if not cluster_spec: 92 raise RuntimeError('Internal error: `_get_session_master` ' 93 'does not expect empty cluster_spec.') 94 95 jobs = cluster_spec.jobs 96 97 # If there is only one node in the cluster, do things locally by setting 98 # master to ''. If a service or user sets TF_CONFIG with a single node, it's 99 # more performant to use a direct master rather than an RPC service. 100 if len(jobs) == 1 and len(cluster_spec.job_tasks(jobs[0])) == 1: 101 return _LOCAL_MASTER 102 103 # Lookup the master in cluster_spec using task_type and task_id, 104 # if possible. 105 addresses = cluster_spec.job_tasks(task_type) 106 return _GRPC_SCHEME + addresses[task_id] 107 108 109 def _get_eval_session_master(task_type, tf_config): 110 """Returns the appropriate address for TensorFlow evaluation master.""" 111 if task_type == TaskType.EVALUATOR: 112 return tf_config.get(_EVAL_SESSION_MASTER_KEY, _LOCAL_MASTER) 113 114 if _EVAL_SESSION_MASTER_KEY in tf_config: 115 raise ValueError('Key ({}) should not be set for task type other than {}. ' 116 'Task type: {}'.format(_EVAL_SESSION_MASTER_KEY, 117 TaskType.EVALUATOR, task_type)) 118 return _LOCAL_MASTER 119 120 121 def _count_ps(cluster_spec): 122 """Counts the number of parameter servers in cluster_spec.""" 123 if not cluster_spec: 124 raise RuntimeError( 125 'Internal error: `_count_ps` does not expect empty cluster_spec.') 126 127 return len(cluster_spec.as_dict().get(TaskType.PS, [])) 128 129 130 def _count_worker(cluster_spec, chief_task_type): 131 """Counts the number of workers (including chief) in cluster_spec.""" 132 if not cluster_spec: 133 raise RuntimeError( 134 'Internal error: `_count_worker` does not expect empty cluster_spec.') 135 136 return (len(cluster_spec.as_dict().get(TaskType.WORKER, [])) + 137 len(cluster_spec.as_dict().get(chief_task_type, []))) 138 139 140 def _validate_service(service): 141 """Validates the service key.""" 142 if service is not None and not isinstance(service, dict): 143 raise TypeError( 144 'If "service" is set in TF_CONFIG, it must be a dict. Given %s' % 145 type(service)) 146 return service 147 148 149 def _validate_task_type_and_task_id(cluster_spec, task_env, chief_task_type): 150 """Validates the task type and index in `task_env` according to cluster.""" 151 if chief_task_type not in cluster_spec.jobs: 152 raise ValueError( 153 'If "cluster" is set in TF_CONFIG, it must have one "%s" node.' % 154 chief_task_type) 155 if len(cluster_spec.job_tasks(chief_task_type)) > 1: 156 raise ValueError( 157 'The "cluster" in TF_CONFIG must have only one "%s" node.' % 158 chief_task_type) 159 160 task_type = task_env.get(_TASK_TYPE_KEY, None) 161 task_id = task_env.get(_TASK_ID_KEY, None) 162 163 if not task_type: 164 raise ValueError( 165 'If "cluster" is set in TF_CONFIG, task type must be set.') 166 if task_id is None: 167 raise ValueError( 168 'If "cluster" is set in TF_CONFIG, task index must be set.') 169 170 task_id = int(task_id) 171 172 # Check the task id bounds. Upper bound is not necessary as 173 # - for evaluator, there is no upper bound. 174 # - for non-evaluator, task id is upper bounded by the number of jobs in 175 # cluster spec, which will be checked later (when retrieving the `master`) 176 if task_id < 0: 177 raise ValueError('Task index must be non-negative number.') 178 179 # Evaluator is not part of the training cluster. 180 if task_type == TaskType.EVALUATOR: 181 return task_type, task_id 182 183 if task_type not in cluster_spec.jobs: 184 raise ValueError( 185 '%s is not a valid task_type in the cluster_spec:\n' 186 '%s\n\n' 187 'Note that these values may be coming from the TF_CONFIG environment ' 188 'variable.' % (task_type, cluster_spec)) 189 addresses = cluster_spec.job_tasks(task_type) 190 if not 0 <= task_id < len(addresses): 191 raise ValueError( 192 '%d is not a valid task_id for task_type %s in the cluster_spec:\n' 193 '%s\n\n' 194 'Note that these values may be coming from the TF_CONFIG environment ' 195 'variable.' % (task_id, task_type, cluster_spec)) 196 197 return task_type, task_id 198 199 200 def _get_global_id_in_cluster( 201 cluster_spec, task_type, task_id, chief_task_type): 202 """Returns the global id in cluster.""" 203 # Note: This is implementation details, which user should not rely on. 204 # The first id is 0, which is always for the `chief` node. All other nodes, 205 # except `ps`, are ordered alphabetical based on task type (alphabetically) 206 # and task id (ascendingly). `ps` are ordered last. 207 208 # Sort task names in cluster 209 task_type_ordered_list = [chief_task_type] 210 task_type_ordered_list.extend([ 211 t for t in sorted(cluster_spec.jobs) 212 if t != chief_task_type and t != TaskType.PS 213 ]) 214 if TaskType.PS in cluster_spec.jobs: 215 task_type_ordered_list.append(TaskType.PS) 216 217 next_global_id = 0 218 for t in task_type_ordered_list: 219 if t == task_type: 220 return next_global_id + task_id 221 next_global_id += len(cluster_spec.job_tasks(t)) 222 223 # This should never happen. 224 raise RuntimeError('Internal Error: `task_type` ({}) is not in ' 225 'cluster_spec ({}).'.format(task_type, cluster_spec)) 226 227 228 def _validate_save_ckpt_with_replaced_keys(new_copy, replaced_keys): 229 """Validates the save ckpt properties.""" 230 # Ensure one (and only one) of save_steps and save_secs is not None. 231 # Also, if user sets one save ckpt property, say steps, the other one (secs) 232 # should be set as None to improve usability. 233 234 save_steps = new_copy.save_checkpoints_steps 235 save_secs = new_copy.save_checkpoints_secs 236 237 if ('save_checkpoints_steps' in replaced_keys and 238 'save_checkpoints_secs' in replaced_keys): 239 # If user sets both properties explicitly, we need to error out if both 240 # are set or neither of them are set. 241 if save_steps is not None and save_secs is not None: 242 raise ValueError(_SAVE_CKPT_ERR) 243 elif 'save_checkpoints_steps' in replaced_keys and save_steps is not None: 244 new_copy._save_checkpoints_secs = None # pylint: disable=protected-access 245 elif 'save_checkpoints_secs' in replaced_keys and save_secs is not None: 246 new_copy._save_checkpoints_steps = None # pylint: disable=protected-access 247 248 249 def _validate_properties(run_config): 250 """Validates the properties.""" 251 def _validate(property_name, cond, message): 252 property_value = getattr(run_config, property_name) 253 if property_value is not None and not cond(property_value): 254 raise ValueError(message) 255 256 _validate('model_dir', lambda dir: dir, 257 message='model_dir should be non-empty') 258 259 _validate('save_summary_steps', lambda steps: steps >= 0, 260 message='save_summary_steps should be >= 0') 261 262 _validate('save_checkpoints_steps', lambda steps: steps >= 0, 263 message='save_checkpoints_steps should be >= 0') 264 _validate('save_checkpoints_secs', lambda secs: secs >= 0, 265 message='save_checkpoints_secs should be >= 0') 266 267 _validate('session_config', 268 lambda sc: isinstance(sc, config_pb2.ConfigProto), 269 message='session_config must be instance of ConfigProto') 270 271 _validate('keep_checkpoint_max', lambda keep_max: keep_max >= 0, 272 message='keep_checkpoint_max should be >= 0') 273 _validate('keep_checkpoint_every_n_hours', lambda keep_hours: keep_hours > 0, 274 message='keep_checkpoint_every_n_hours should be > 0') 275 _validate('log_step_count_steps', lambda num_steps: num_steps > 0, 276 message='log_step_count_steps should be > 0') 277 278 _validate('tf_random_seed', lambda seed: isinstance(seed, six.integer_types), 279 message='tf_random_seed must be integer.') 280 281 282 class TaskType(object): 283 MASTER = 'master' 284 PS = 'ps' 285 WORKER = 'worker' 286 CHIEF = 'chief' 287 EVALUATOR = 'evaluator' 288 289 290 @tf_export('estimator.RunConfig') 291 class RunConfig(object): 292 """This class specifies the configurations for an `Estimator` run.""" 293 294 def __init__(self, 295 model_dir=None, 296 tf_random_seed=None, 297 save_summary_steps=100, 298 save_checkpoints_steps=_USE_DEFAULT, 299 save_checkpoints_secs=_USE_DEFAULT, 300 session_config=None, 301 keep_checkpoint_max=5, 302 keep_checkpoint_every_n_hours=10000, 303 log_step_count_steps=100): 304 """Constructs a RunConfig. 305 306 All distributed training related properties `cluster_spec`, `is_chief`, 307 `master` , `num_worker_replicas`, `num_ps_replicas`, `task_id`, and 308 `task_type` are set based on the `TF_CONFIG` environment variable, if the 309 pertinent information is present. The `TF_CONFIG` environment variable is a 310 JSON object with attributes: `cluster` and `task`. 311 312 `cluster` is a JSON serialized version of `ClusterSpec`'s Python dict from 313 `server_lib.py`, mapping task types (usually one of the `TaskType` enums) to 314 a list of task addresses. 315 316 `task` has two attributes: `type` and `index`, where `type` can be any of 317 the task types in `cluster`. ` When `TF_CONFIG` contains said information, 318 the following properties are set on this class: 319 320 * `cluster_spec` is parsed from `TF_CONFIG['cluster']`. Defaults to {}. If 321 present, must have one and only one node in the `chief` attribute of 322 `cluster_spec`. 323 * `task_type` is set to `TF_CONFIG['task']['type']`. Must set if 324 `cluster_spec` is present; must be `worker` (the default value) if 325 `cluster_spec` is not set. 326 * `task_id` is set to `TF_CONFIG['task']['index']`. Must set if 327 `cluster_spec` is present; must be 0 (the default value) if 328 `cluster_spec` is not set. 329 * `master` is determined by looking up `task_type` and `task_id` in the 330 `cluster_spec`. Defaults to ''. 331 * `num_ps_replicas` is set by counting the number of nodes listed 332 in the `ps` attribute of `cluster_spec`. Defaults to 0. 333 * `num_worker_replicas` is set by counting the number of nodes listed 334 in the `worker` and `chief` attributes of `cluster_spec`. Defaults to 1. 335 * `is_chief` is determined based on `task_type` and `cluster`. 336 337 There is a special node with `task_type` as `evaluator`, which is not part 338 of the (training) `cluster_spec`. It handles the distributed evaluation job. 339 340 Example of non-chief node: 341 ``` 342 cluster = {'chief': ['host0:2222'], 343 'ps': ['host1:2222', 'host2:2222'], 344 'worker': ['host3:2222', 'host4:2222', 'host5:2222']} 345 os.environ['TF_CONFIG'] = json.dumps( 346 {'cluster': cluster, 347 'task': {'type': 'worker', 'index': 1}}) 348 config = ClusterConfig() 349 assert config.master == 'host4:2222' 350 assert config.task_id == 1 351 assert config.num_ps_replicas == 2 352 assert config.num_worker_replicas == 4 353 assert config.cluster_spec == server_lib.ClusterSpec(cluster) 354 assert config.task_type == 'worker' 355 assert not config.is_chief 356 ``` 357 358 Example of chief node: 359 ``` 360 cluster = {'chief': ['host0:2222'], 361 'ps': ['host1:2222', 'host2:2222'], 362 'worker': ['host3:2222', 'host4:2222', 'host5:2222']} 363 os.environ['TF_CONFIG'] = json.dumps( 364 {'cluster': cluster, 365 'task': {'type': 'chief', 'index': 0}}) 366 config = ClusterConfig() 367 assert config.master == 'host0:2222' 368 assert config.task_id == 0 369 assert config.num_ps_replicas == 2 370 assert config.num_worker_replicas == 4 371 assert config.cluster_spec == server_lib.ClusterSpec(cluster) 372 assert config.task_type == 'chief' 373 assert config.is_chief 374 ``` 375 376 Example of evaluator node (evaluator is not part of training cluster): 377 ``` 378 cluster = {'chief': ['host0:2222'], 379 'ps': ['host1:2222', 'host2:2222'], 380 'worker': ['host3:2222', 'host4:2222', 'host5:2222']} 381 os.environ['TF_CONFIG'] = json.dumps( 382 {'cluster': cluster, 383 'task': {'type': 'evaluator', 'index': 0}}) 384 config = ClusterConfig() 385 assert config.master == '' 386 assert config.evaluator_master == '' 387 assert config.task_id == 0 388 assert config.num_ps_replicas == 0 389 assert config.num_worker_replicas == 0 390 assert config.cluster_spec == {} 391 assert config.task_type == 'evaluator' 392 assert not config.is_chief 393 ``` 394 395 N.B.: If `save_checkpoints_steps` or `save_checkpoints_secs` is set, 396 `keep_checkpoint_max` might need to be adjusted accordingly, especially in 397 distributed training. For example, setting `save_checkpoints_secs` as 60 398 without adjusting `keep_checkpoint_max` (defaults to 5) leads to situation 399 that checkpoint would be garbage collected after 5 minutes. In distributed 400 training, the evaluation job starts asynchronously and might fail to load or 401 find the checkpoint due to race condition. 402 403 Args: 404 model_dir: directory where model parameters, graph, etc are saved. If 405 `PathLike` object, the path will be resolved. If `None`, will use a 406 default value set by the Estimator. 407 tf_random_seed: Random seed for TensorFlow initializers. 408 Setting this value allows consistency between reruns. 409 save_summary_steps: Save summaries every this many steps. 410 save_checkpoints_steps: Save checkpoints every this many steps. Can not be 411 specified with `save_checkpoints_secs`. 412 save_checkpoints_secs: Save checkpoints every this many seconds. Can not 413 be specified with `save_checkpoints_steps`. Defaults to 600 seconds if 414 both `save_checkpoints_steps` and `save_checkpoints_secs` are not set 415 in constructor. If both `save_checkpoints_steps` and 416 `save_checkpoints_secs` are None, then checkpoints are disabled. 417 session_config: a ConfigProto used to set session parameters, or None. 418 keep_checkpoint_max: The maximum number of recent checkpoint files to 419 keep. As new files are created, older files are deleted. If None or 0, 420 all checkpoint files are kept. Defaults to 5 (that is, the 5 most recent 421 checkpoint files are kept.) 422 keep_checkpoint_every_n_hours: Number of hours between each checkpoint 423 to be saved. The default value of 10,000 hours effectively disables 424 the feature. 425 log_step_count_steps: The frequency, in number of global steps, that the 426 global step/sec will be logged during training. 427 428 429 Raises: 430 ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs` 431 are set. 432 """ 433 if (save_checkpoints_steps == _USE_DEFAULT and 434 save_checkpoints_secs == _USE_DEFAULT): 435 save_checkpoints_steps = None 436 save_checkpoints_secs = 600 437 elif save_checkpoints_secs == _USE_DEFAULT: 438 save_checkpoints_secs = None 439 elif save_checkpoints_steps == _USE_DEFAULT: 440 save_checkpoints_steps = None 441 elif (save_checkpoints_steps is not None and 442 save_checkpoints_secs is not None): 443 raise ValueError(_SAVE_CKPT_ERR) 444 445 tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV, '{}')) 446 if tf_config: 447 logging.info('TF_CONFIG environment variable: %s', tf_config) 448 449 model_dir = _get_model_dir(tf_config, 450 compat_internal.path_to_str(model_dir)) 451 452 RunConfig._replace( 453 self, 454 allowed_properties_list=_DEFAULT_REPLACEABLE_LIST, 455 model_dir=model_dir, 456 tf_random_seed=tf_random_seed, 457 save_summary_steps=save_summary_steps, 458 save_checkpoints_steps=save_checkpoints_steps, 459 save_checkpoints_secs=save_checkpoints_secs, 460 session_config=session_config, 461 keep_checkpoint_max=keep_checkpoint_max, 462 keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, 463 log_step_count_steps=log_step_count_steps) 464 465 self._init_distributed_setting_from_environment_var(tf_config) 466 467 def _init_distributed_setting_from_environment_var(self, tf_config): 468 """Initialize distributed properties based on `tf_config`.""" 469 470 self._service = _validate_service(tf_config.get(_SERVICE_KEY)) 471 self._cluster_spec = server_lib.ClusterSpec(tf_config.get(_CLUSTER_KEY, {})) 472 task_env = tf_config.get(_TASK_ENV_KEY, {}) 473 474 if self._cluster_spec and TaskType.MASTER in self._cluster_spec.jobs: 475 return self._init_distributed_setting_from_environment_var_with_master( 476 tf_config) 477 478 if self._cluster_spec: 479 # Distributed mode. 480 self._task_type, self._task_id = _validate_task_type_and_task_id( 481 self._cluster_spec, task_env, TaskType.CHIEF) 482 483 self._evaluation_master = _get_eval_session_master( 484 self._task_type, tf_config) 485 486 if self._task_type != TaskType.EVALUATOR: 487 self._master = _get_session_master(self._cluster_spec, self._task_type, 488 self._task_id, tf_config) 489 self._num_ps_replicas = _count_ps(self._cluster_spec) 490 self._num_worker_replicas = _count_worker( 491 self._cluster_spec, chief_task_type=TaskType.CHIEF) 492 self._global_id_in_cluster = _get_global_id_in_cluster( 493 self._cluster_spec, 494 self._task_type, 495 self._task_id, 496 chief_task_type=TaskType.CHIEF) 497 else: 498 # Evaluator is not part of the training cluster. 499 self._cluster_spec = server_lib.ClusterSpec({}) 500 self._master = _LOCAL_MASTER 501 self._num_ps_replicas = 0 502 self._num_worker_replicas = 0 503 self._global_id_in_cluster = None # undefined 504 505 self._is_chief = self._task_type == TaskType.CHIEF 506 else: 507 # Local mode. 508 self._task_type = task_env.get(_TASK_TYPE_KEY, TaskType.WORKER) 509 self._task_id = int(task_env.get(_TASK_ID_KEY, 0)) 510 self._global_id_in_cluster = 0 511 512 if self._task_type != TaskType.WORKER: 513 raise ValueError( 514 'If "cluster" is not set in TF_CONFIG, task type must be WORKER.') 515 if self._task_id != 0: 516 raise ValueError( 517 'If "cluster" is not set in TF_CONFIG, task index must be 0.') 518 519 self._master = tf_config.get(_SESSION_MASTER_KEY, _LOCAL_MASTER) 520 self._evaluation_master = tf_config.get(_EVAL_SESSION_MASTER_KEY, 521 _LOCAL_MASTER) 522 self._is_chief = True 523 self._num_ps_replicas = 0 524 self._num_worker_replicas = 1 525 526 def _init_distributed_setting_from_environment_var_with_master(self, 527 tf_config): 528 """Initialize distributed properties for legacy cluster with `master`.""" 529 # There is no tech reason, why user cannot have chief and master in the same 530 # cluster, but it is super confusing (which is really the chief?). So, block 531 # this case. 532 if TaskType.CHIEF in self._cluster_spec.jobs: 533 raise ValueError('If `master` node exists in `cluster`, job ' 534 '`chief` is not supported.') 535 536 task_env = tf_config.get(_TASK_ENV_KEY, {}) 537 538 self._task_type, self._task_id = _validate_task_type_and_task_id( 539 self._cluster_spec, task_env, TaskType.MASTER) 540 541 if self._task_type == TaskType.EVALUATOR: 542 raise ValueError('If `master` node exists in `cluster`, task_type ' 543 '`evaluator` is not supported.') 544 545 self._global_id_in_cluster = _get_global_id_in_cluster( 546 self._cluster_spec, 547 self._task_type, 548 self._task_id, 549 chief_task_type=TaskType.MASTER) 550 551 self._master = _get_session_master(self._cluster_spec, self._task_type, 552 self._task_id, tf_config) 553 self._evaluation_master = _get_eval_session_master(self._task_type, 554 tf_config) 555 self._num_ps_replicas = _count_ps(self._cluster_spec) 556 self._num_worker_replicas = _count_worker( 557 self._cluster_spec, chief_task_type=TaskType.MASTER) 558 559 self._is_chief = self._task_type == TaskType.MASTER 560 561 @property 562 def cluster_spec(self): 563 return self._cluster_spec 564 565 @property 566 def evaluation_master(self): 567 return self._evaluation_master 568 569 @property 570 def is_chief(self): 571 return self._is_chief 572 573 @property 574 def master(self): 575 return self._master 576 577 @property 578 def num_ps_replicas(self): 579 return self._num_ps_replicas 580 581 @property 582 def num_worker_replicas(self): 583 return self._num_worker_replicas 584 585 @property 586 def task_id(self): 587 return self._task_id 588 589 @property 590 def global_id_in_cluster(self): 591 """The global id in the training cluster. 592 593 All global ids in the training cluster are assigned from an increasing 594 sequence of consecutive integers. The first id is 0. 595 596 Note: Task id (the property field `task_id`) is tracking the index of the 597 node among all nodes with the SAME task type. For example, given the cluster 598 definition as follows: 599 600 ``` 601 cluster = {'chief': ['host0:2222'], 602 'ps': ['host1:2222', 'host2:2222'], 603 'worker': ['host3:2222', 'host4:2222', 'host5:2222']} 604 ``` 605 606 Nodes with task type `worker` can have id 0, 1, 2. Nodes with task type 607 `ps` can have id, 0, 1. So, `task_id` is not unique, but the pair 608 (`task_type`, `task_id`) can uniquely determine a node in the cluster. 609 610 Global id, i.e., this field, is tracking the index of the node among ALL 611 nodes in the cluster. It is uniquely assigned. For example, for the cluster 612 spec given above, the global ids are assigned as: 613 ``` 614 task_type | task_id | global_id 615 -------------------------------- 616 chief | 0 | 0 617 worker | 0 | 1 618 worker | 1 | 2 619 worker | 2 | 3 620 ps | 0 | 4 621 ps | 1 | 5 622 ``` 623 624 Returns: 625 An integer id. 626 """ 627 return self._global_id_in_cluster 628 629 @property 630 def task_type(self): 631 return self._task_type 632 633 @property 634 def tf_random_seed(self): 635 return self._tf_random_seed 636 637 @property 638 def save_summary_steps(self): 639 return self._save_summary_steps 640 641 @property 642 def save_checkpoints_secs(self): 643 return self._save_checkpoints_secs 644 645 @property 646 def session_config(self): 647 return self._session_config 648 649 @property 650 def save_checkpoints_steps(self): 651 return self._save_checkpoints_steps 652 653 @property 654 def keep_checkpoint_max(self): 655 return self._keep_checkpoint_max 656 657 @property 658 def keep_checkpoint_every_n_hours(self): 659 return self._keep_checkpoint_every_n_hours 660 661 @property 662 def log_step_count_steps(self): 663 return self._log_step_count_steps 664 665 @property 666 def model_dir(self): 667 return self._model_dir 668 669 @property 670 def service(self): 671 """Returns the platform defined (in TF_CONFIG) service dict.""" 672 return self._service 673 674 def replace(self, **kwargs): 675 """Returns a new instance of `RunConfig` replacing specified properties. 676 677 Only the properties in the following list are allowed to be replaced: 678 679 - `model_dir`. 680 - `tf_random_seed`, 681 - `save_summary_steps`, 682 - `save_checkpoints_steps`, 683 - `save_checkpoints_secs`, 684 - `session_config`, 685 - `keep_checkpoint_max`, 686 - `keep_checkpoint_every_n_hours`, 687 - `log_step_count_steps`, 688 689 In addition, either `save_checkpoints_steps` or `save_checkpoints_secs` 690 can be set (should not be both). 691 692 Args: 693 **kwargs: keyword named properties with new values. 694 695 Raises: 696 ValueError: If any property name in `kwargs` does not exist or is not 697 allowed to be replaced, or both `save_checkpoints_steps` and 698 `save_checkpoints_secs` are set. 699 700 Returns: 701 a new instance of `RunConfig`. 702 """ 703 return RunConfig._replace( 704 copy.deepcopy(self), 705 allowed_properties_list=_DEFAULT_REPLACEABLE_LIST, 706 **kwargs) 707 708 @staticmethod 709 def _replace(config, allowed_properties_list=None, **kwargs): 710 """See `replace`. 711 712 N.B.: This implementation assumes that for key named "foo", the underlying 713 property the RunConfig holds is "_foo" (with one leading underscore). 714 715 Args: 716 config: The RunConfig to replace the values of. 717 allowed_properties_list: The property name list allowed to be replaced. 718 **kwargs: keyword named properties with new values. 719 720 Raises: 721 ValueError: If any property name in `kwargs` does not exist or is not 722 allowed to be replaced, or both `save_checkpoints_steps` and 723 `save_checkpoints_secs` are set. 724 725 Returns: 726 a new instance of `RunConfig`. 727 """ 728 729 allowed_properties_list = allowed_properties_list or [] 730 731 for key, new_value in six.iteritems(kwargs): 732 if key in allowed_properties_list: 733 setattr(config, '_' + key, new_value) 734 continue 735 736 raise ValueError( 737 'Replacing {} is not supported. Allowed properties are {}.'.format( 738 key, allowed_properties_list)) 739 740 _validate_save_ckpt_with_replaced_keys(config, kwargs.keys()) 741 _validate_properties(config) 742 return config 743 744 745 def _get_model_dir(tf_config, model_dir): 746 """Returns `model_dir` based user provided `tf_config` or `model_dir`.""" 747 # pylint: disable=g-explicit-bool-comparison 748 749 # Empty string is treated as False in Python condition check, which triggers 750 # some confusing error messages. For example, 'a or b' returns None if a is '' 751 # and b is None. `None` is allowed for model_dir but '' is not allowed. Here, 752 # explicitly check empty string to provide clear error message. 753 if model_dir == '': 754 raise ValueError('model_dir should be non-empty.') 755 756 model_dir_in_tf_config = tf_config.get('model_dir') 757 if model_dir_in_tf_config == '': 758 raise ValueError('model_dir in TF_CONFIG should be non-empty.') 759 760 if model_dir_in_tf_config: 761 if model_dir and model_dir_in_tf_config != model_dir: 762 raise ValueError( 763 '`model_dir` provided in RunConfig construct, if set, ' 764 'must have the same value as the model_dir in TF_CONFIG. ' 765 'model_dir: {}\nTF_CONFIG["model_dir"]: {}.\n'.format( 766 model_dir, model_dir_in_tf_config)) 767 768 logging.info('Using model_dir in TF_CONFIG: %s', model_dir_in_tf_config) 769 770 return model_dir or model_dir_in_tf_config 771