1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Training helper that checkpoints models and creates session.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import time 21 import numpy as np 22 23 from tensorflow.python.client import session 24 from tensorflow.python.framework import errors 25 from tensorflow.python.framework import ops 26 from tensorflow.python.platform import tf_logging as logging 27 from tensorflow.python.training import saver as saver_mod 28 from tensorflow.python.util.tf_export import tf_export 29 30 31 def _maybe_name(obj): 32 """Returns object name if it has one, or a message otherwise. 33 34 This is useful for names that apper in error messages. 35 Args: 36 obj: Object to get the name of. 37 Returns: 38 name, "None", or a "no name" message. 39 """ 40 if obj is None: 41 return "None" 42 elif hasattr(obj, "name"): 43 return obj.name 44 else: 45 return "<no name for %s>" % type(obj) 46 47 48 @tf_export("train.SessionManager") 49 class SessionManager(object): 50 """Training helper that restores from checkpoint and creates session. 51 52 This class is a small wrapper that takes care of session creation and 53 checkpoint recovery. It also provides functions that to facilitate 54 coordination among multiple training threads or processes. 55 56 * Checkpointing trained variables as the training progresses. 57 * Initializing variables on startup, restoring them from the most recent 58 checkpoint after a crash, or wait for checkpoints to become available. 59 60 ### Usage: 61 62 ```python 63 with tf.Graph().as_default(): 64 ...add operations to the graph... 65 # Create a SessionManager that will checkpoint the model in '/tmp/mydir'. 66 sm = SessionManager() 67 sess = sm.prepare_session(master, init_op, saver, checkpoint_dir) 68 # Use the session to train the graph. 69 while True: 70 sess.run(<my_train_op>) 71 ``` 72 73 `prepare_session()` initializes or restores a model. It requires `init_op` 74 and `saver` as an argument. 75 76 A second process could wait for the model to be ready by doing the following: 77 78 ```python 79 with tf.Graph().as_default(): 80 ...add operations to the graph... 81 # Create a SessionManager that will wait for the model to become ready. 82 sm = SessionManager() 83 sess = sm.wait_for_session(master) 84 # Use the session to train the graph. 85 while True: 86 sess.run(<my_train_op>) 87 ``` 88 89 `wait_for_session()` waits for a model to be initialized by other processes. 90 91 """ 92 93 def __init__(self, 94 local_init_op=None, 95 ready_op=None, 96 ready_for_local_init_op=None, 97 graph=None, 98 recovery_wait_secs=30): 99 """Creates a SessionManager. 100 101 The `local_init_op` is an `Operation` that is run always after a new session 102 was created. If `None`, this step is skipped. 103 104 The `ready_op` is an `Operation` used to check if the model is ready. The 105 model is considered ready if that operation returns an empty 1D string 106 tensor. If the operation returns a non empty 1D string tensor, the elements 107 are concatenated and used to indicate to the user why the model is not 108 ready. 109 110 The `ready_for_local_init_op` is an `Operation` used to check if the model 111 is ready to run local_init_op. The model is considered ready if that 112 operation returns an empty 1D string tensor. If the operation returns a non 113 empty 1D string tensor, the elements are concatenated and used to indicate 114 to the user why the model is not ready. 115 116 If `ready_op` is `None`, the model is not checked for readiness. 117 118 `recovery_wait_secs` is the number of seconds between checks that 119 the model is ready. It is used by processes to wait for a model to 120 be initialized or restored. Defaults to 30 seconds. 121 122 Args: 123 local_init_op: An `Operation` run immediately after session creation. 124 Usually used to initialize tables and local variables. 125 ready_op: An `Operation` to check if the model is initialized. 126 ready_for_local_init_op: An `Operation` to check if the model is ready 127 to run local_init_op. 128 graph: The `Graph` that the model will use. 129 recovery_wait_secs: Seconds between checks for the model to be ready. 130 131 Raises: 132 ValueError: If ready_for_local_init_op is not None but local_init_op is 133 None 134 """ 135 # Sets default values of arguments. 136 if graph is None: 137 graph = ops.get_default_graph() 138 self._local_init_op = local_init_op 139 self._ready_op = ready_op 140 self._ready_for_local_init_op = ready_for_local_init_op 141 self._graph = graph 142 self._recovery_wait_secs = recovery_wait_secs 143 self._target = None 144 if ready_for_local_init_op is not None and local_init_op is None: 145 raise ValueError("If you pass a ready_for_local_init_op " 146 "you must also pass a local_init_op " 147 ", ready_for_local_init_op [%s]" % 148 ready_for_local_init_op) 149 150 def _restore_checkpoint(self, 151 master, 152 saver=None, 153 checkpoint_dir=None, 154 checkpoint_filename_with_path=None, 155 wait_for_checkpoint=False, 156 max_wait_secs=7200, 157 config=None): 158 """Creates a `Session`, and tries to restore a checkpoint. 159 160 161 Args: 162 master: `String` representation of the TensorFlow master to use. 163 saver: A `Saver` object used to restore a model. 164 checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the 165 dir will be used to restore. 166 checkpoint_filename_with_path: Full file name path to the checkpoint file. 167 wait_for_checkpoint: Whether to wait for checkpoint to become available. 168 max_wait_secs: Maximum time to wait for checkpoints to become available. 169 config: Optional `ConfigProto` proto used to configure the session. 170 171 Returns: 172 A pair (sess, is_restored) where 'is_restored' is `True` if 173 the session could be restored, `False` otherwise. 174 175 Raises: 176 ValueError: If both checkpoint_dir and checkpoint_filename_with_path are 177 set. 178 """ 179 self._target = master 180 sess = session.Session(self._target, graph=self._graph, config=config) 181 182 if checkpoint_dir and checkpoint_filename_with_path: 183 raise ValueError("Can not provide both checkpoint_dir and " 184 "checkpoint_filename_with_path.") 185 # If either saver or checkpoint_* is not specified, cannot restore. Just 186 # return. 187 if not saver or not (checkpoint_dir or checkpoint_filename_with_path): 188 return sess, False 189 190 if checkpoint_filename_with_path: 191 saver.restore(sess, checkpoint_filename_with_path) 192 return sess, True 193 194 # Waits up until max_wait_secs for checkpoint to become available. 195 wait_time = 0 196 ckpt = saver_mod.get_checkpoint_state(checkpoint_dir) 197 while not ckpt or not ckpt.model_checkpoint_path: 198 if wait_for_checkpoint and wait_time < max_wait_secs: 199 logging.info("Waiting for checkpoint to be available.") 200 time.sleep(self._recovery_wait_secs) 201 wait_time += self._recovery_wait_secs 202 ckpt = saver_mod.get_checkpoint_state(checkpoint_dir) 203 else: 204 return sess, False 205 206 # Loads the checkpoint. 207 saver.restore(sess, ckpt.model_checkpoint_path) 208 saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths) 209 return sess, True 210 211 def prepare_session(self, 212 master, 213 init_op=None, 214 saver=None, 215 checkpoint_dir=None, 216 checkpoint_filename_with_path=None, 217 wait_for_checkpoint=False, 218 max_wait_secs=7200, 219 config=None, 220 init_feed_dict=None, 221 init_fn=None): 222 """Creates a `Session`. Makes sure the model is ready to be used. 223 224 Creates a `Session` on 'master'. If a `saver` object is passed in, and 225 `checkpoint_dir` points to a directory containing valid checkpoint 226 files, then it will try to recover the model from checkpoint. If 227 no checkpoint files are available, and `wait_for_checkpoint` is 228 `True`, then the process would check every `recovery_wait_secs`, 229 up to `max_wait_secs`, for recovery to succeed. 230 231 If the model cannot be recovered successfully then it is initialized by 232 either running the provided `init_op`, or calling the provided `init_fn`. 233 The local_init_op is also run after init_op and init_fn, regardless of 234 whether the model was recovered successfully, but only if 235 ready_for_local_init_op passes. 236 237 It is an error if the model cannot be recovered and no `init_op` 238 or `init_fn` or `local_init_op` are passed. 239 240 Args: 241 master: `String` representation of the TensorFlow master to use. 242 init_op: Optional `Operation` used to initialize the model. 243 saver: A `Saver` object used to restore a model. 244 checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the 245 dir will be used to restore. 246 checkpoint_filename_with_path: Full file name path to the checkpoint file. 247 wait_for_checkpoint: Whether to wait for checkpoint to become available. 248 max_wait_secs: Maximum time to wait for checkpoints to become available. 249 config: Optional `ConfigProto` proto used to configure the session. 250 init_feed_dict: Optional dictionary that maps `Tensor` objects to feed 251 values. This feed dictionary is passed to the session `run()` call when 252 running the init op. 253 init_fn: Optional callable used to initialize the model. Called after the 254 optional `init_op` is called. The callable must accept one argument, 255 the session being initialized. 256 257 Returns: 258 A `Session` object that can be used to drive the model. 259 260 Raises: 261 RuntimeError: If the model cannot be initialized or recovered. 262 263 Raises: 264 ValueError: If both checkpoint_dir and checkpoint_filename_with_path are 265 set. 266 """ 267 268 sess, is_loaded_from_checkpoint = self._restore_checkpoint( 269 master, 270 saver, 271 checkpoint_dir=checkpoint_dir, 272 checkpoint_filename_with_path=checkpoint_filename_with_path, 273 wait_for_checkpoint=wait_for_checkpoint, 274 max_wait_secs=max_wait_secs, 275 config=config) 276 if not is_loaded_from_checkpoint: 277 if init_op is None and not init_fn and self._local_init_op is None: 278 raise RuntimeError("Model is not initialized and no init_op or " 279 "init_fn or local_init_op was given") 280 if init_op is not None: 281 sess.run(init_op, feed_dict=init_feed_dict) 282 if init_fn: 283 init_fn(sess) 284 285 local_init_success, msg = self._try_run_local_init_op(sess) 286 if not local_init_success: 287 raise RuntimeError( 288 "Init operations did not make model ready for local_init. " 289 "Init op: %s, init fn: %s, error: %s" % (_maybe_name(init_op), 290 init_fn, 291 msg)) 292 293 is_ready, msg = self._model_ready(sess) 294 if not is_ready: 295 raise RuntimeError( 296 "Init operations did not make model ready. " 297 "Init op: %s, init fn: %s, local_init_op: %s, error: %s" % 298 (_maybe_name(init_op), init_fn, self._local_init_op, msg)) 299 return sess 300 301 def recover_session(self, 302 master, 303 saver=None, 304 checkpoint_dir=None, 305 checkpoint_filename_with_path=None, 306 wait_for_checkpoint=False, 307 max_wait_secs=7200, 308 config=None): 309 """Creates a `Session`, recovering if possible. 310 311 Creates a new session on 'master'. If the session is not initialized 312 and can be recovered from a checkpoint, recover it. 313 314 Args: 315 master: `String` representation of the TensorFlow master to use. 316 saver: A `Saver` object used to restore a model. 317 checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the 318 dir will be used to restore. 319 checkpoint_filename_with_path: Full file name path to the checkpoint file. 320 wait_for_checkpoint: Whether to wait for checkpoint to become available. 321 max_wait_secs: Maximum time to wait for checkpoints to become available. 322 config: Optional `ConfigProto` proto used to configure the session. 323 324 Returns: 325 A pair (sess, initialized) where 'initialized' is `True` if 326 the session could be recovered and initialized, `False` otherwise. 327 328 Raises: 329 ValueError: If both checkpoint_dir and checkpoint_filename_with_path are 330 set. 331 """ 332 333 sess, is_loaded_from_checkpoint = self._restore_checkpoint( 334 master, 335 saver, 336 checkpoint_dir=checkpoint_dir, 337 checkpoint_filename_with_path=checkpoint_filename_with_path, 338 wait_for_checkpoint=wait_for_checkpoint, 339 max_wait_secs=max_wait_secs, 340 config=config) 341 342 # Always try to run local_init_op 343 local_init_success, msg = self._try_run_local_init_op(sess) 344 345 if not is_loaded_from_checkpoint: 346 # Do not need to run checks for readiness 347 return sess, False 348 349 restoring_file = checkpoint_dir or checkpoint_filename_with_path 350 if not local_init_success: 351 logging.info( 352 "Restoring model from %s did not make model ready for local init:" 353 " %s", restoring_file, msg) 354 return sess, False 355 356 is_ready, msg = self._model_ready(sess) 357 if not is_ready: 358 logging.info("Restoring model from %s did not make model ready: %s", 359 restoring_file, msg) 360 return sess, False 361 362 logging.info("Restored model from %s", restoring_file) 363 return sess, is_loaded_from_checkpoint 364 365 def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")): 366 """Creates a new `Session` and waits for model to be ready. 367 368 Creates a new `Session` on 'master'. Waits for the model to be 369 initialized or recovered from a checkpoint. It's expected that 370 another thread or process will make the model ready, and that this 371 is intended to be used by threads/processes that participate in a 372 distributed training configuration where a different thread/process 373 is responsible for initializing or recovering the model being trained. 374 375 NB: The amount of time this method waits for the session is bounded 376 by max_wait_secs. By default, this function will wait indefinitely. 377 378 Args: 379 master: `String` representation of the TensorFlow master to use. 380 config: Optional ConfigProto proto used to configure the session. 381 max_wait_secs: Maximum time to wait for the session to become available. 382 383 Returns: 384 A `Session`. May be None if the operation exceeds the timeout 385 specified by config.operation_timeout_in_ms. 386 387 Raises: 388 tf.DeadlineExceededError: if the session is not available after 389 max_wait_secs. 390 """ 391 self._target = master 392 393 if max_wait_secs is None: 394 max_wait_secs = float("Inf") 395 timer = _CountDownTimer(max_wait_secs) 396 397 while True: 398 sess = session.Session(self._target, graph=self._graph, config=config) 399 not_ready_msg = None 400 not_ready_local_msg = None 401 local_init_success, not_ready_local_msg = self._try_run_local_init_op( 402 sess) 403 if local_init_success: 404 # Successful if local_init_op is None, or ready_for_local_init_op passes 405 is_ready, not_ready_msg = self._model_ready(sess) 406 if is_ready: 407 return sess 408 409 self._safe_close(sess) 410 411 # Do we have enough time left to try again? 412 remaining_ms_after_wait = ( 413 timer.secs_remaining() - self._recovery_wait_secs) 414 if remaining_ms_after_wait < 0: 415 raise errors.DeadlineExceededError( 416 None, None, 417 "Session was not ready after waiting %d secs." % (max_wait_secs,)) 418 419 logging.info("Waiting for model to be ready. " 420 "Ready_for_local_init_op: %s, ready: %s", 421 not_ready_local_msg, not_ready_msg) 422 time.sleep(self._recovery_wait_secs) 423 424 def _safe_close(self, sess): 425 """Closes a session without raising an exception. 426 427 Just like sess.close() but ignores exceptions. 428 429 Args: 430 sess: A `Session`. 431 """ 432 # pylint: disable=broad-except 433 try: 434 sess.close() 435 except Exception: 436 # Intentionally not logging to avoid user complaints that 437 # they get cryptic errors. We really do not care that Close 438 # fails. 439 pass 440 # pylint: enable=broad-except 441 442 def _model_ready(self, sess): 443 """Checks if the model is ready or not. 444 445 Args: 446 sess: A `Session`. 447 448 Returns: 449 A tuple (is_ready, msg), where is_ready is True if ready and False 450 otherwise, and msg is `None` if the model is ready, a `String` with the 451 reason why it is not ready otherwise. 452 """ 453 return _ready(self._ready_op, sess, "Model not ready") 454 455 def _model_ready_for_local_init(self, sess): 456 """Checks if the model is ready to run local_init_op. 457 458 Args: 459 sess: A `Session`. 460 461 Returns: 462 A tuple (is_ready, msg), where is_ready is True if ready to run 463 local_init_op and False otherwise, and msg is `None` if the model is 464 ready to run local_init_op, a `String` with the reason why it is not ready 465 otherwise. 466 """ 467 return _ready(self._ready_for_local_init_op, sess, 468 "Model not ready for local init") 469 470 def _try_run_local_init_op(self, sess): 471 """Tries to run _local_init_op, if not None, and is ready for local init. 472 473 Args: 474 sess: A `Session`. 475 476 Returns: 477 A tuple (is_successful, msg), where is_successful is True if 478 _local_init_op is None, or we ran _local_init_op, and False otherwise; 479 and msg is a `String` with the reason why the model was not ready to run 480 local init. 481 """ 482 if self._local_init_op is not None: 483 is_ready_for_local_init, msg = self._model_ready_for_local_init(sess) 484 if is_ready_for_local_init: 485 logging.info("Running local_init_op.") 486 sess.run(self._local_init_op) 487 logging.info("Done running local_init_op.") 488 return True, None 489 else: 490 return False, msg 491 return True, None 492 493 494 def _ready(op, sess, msg): 495 """Checks if the model is ready or not, as determined by op. 496 497 Args: 498 op: An op, either _ready_op or _ready_for_local_init_op, which defines the 499 readiness of the model. 500 sess: A `Session`. 501 msg: A message to log to warning if not ready 502 503 Returns: 504 A tuple (is_ready, msg), where is_ready is True if ready and False 505 otherwise, and msg is `None` if the model is ready, a `String` with the 506 reason why it is not ready otherwise. 507 """ 508 if op is None: 509 return True, None 510 else: 511 try: 512 ready_value = sess.run(op) 513 # The model is considered ready if ready_op returns an empty 1-D tensor. 514 # Also compare to `None` and dtype being int32 for backward 515 # compatibility. 516 if (ready_value is None or ready_value.dtype == np.int32 or 517 ready_value.size == 0): 518 return True, None 519 else: 520 # TODO(sherrym): If a custom ready_op returns other types of tensor, 521 # or strings other than variable names, this message could be 522 # confusing. 523 non_initialized_varnames = ", ".join( 524 [i.decode("utf-8") for i in ready_value]) 525 return False, "Variables not initialized: " + non_initialized_varnames 526 except errors.FailedPreconditionError as e: 527 if "uninitialized" not in str(e): 528 logging.warning("%s : error [%s]", msg, str(e)) 529 raise e 530 return False, str(e) 531 532 533 class _CountDownTimer(object): 534 535 def __init__(self, duration_secs): 536 self._start_time_secs = time.time() 537 self._duration_secs = duration_secs 538 539 def secs_remaining(self): 540 diff = self._duration_secs - (time.time() - self._start_time_secs) 541 return max(0, diff) 542