1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Training helper that checkpoints models and computes summaries.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import contextlib 21 import os 22 import time 23 24 from tensorflow.core.framework.summary_pb2 import Summary 25 from tensorflow.core.util.event_pb2 import SessionLog 26 from tensorflow.python.eager import context 27 from tensorflow.python.framework import dtypes 28 from tensorflow.python.framework import meta_graph 29 from tensorflow.python.framework import ops 30 from tensorflow.python.ops import control_flow_ops 31 from tensorflow.python.ops import lookup_ops 32 from tensorflow.python.ops import variables 33 from tensorflow.python.platform import tf_logging as logging 34 from tensorflow.python.summary import summary as _summary 35 from tensorflow.python.training import coordinator 36 from tensorflow.python.training import saver as saver_mod 37 from tensorflow.python.training import session_manager as session_manager_mod 38 from tensorflow.python.training import training_util 39 from tensorflow.python.util import deprecation 40 from tensorflow.python.util.tf_export import tf_export 41 42 43 @tf_export("train.Supervisor") 44 class Supervisor(object): 45 """A training helper that checkpoints models and computes summaries. 46 47 This class is deprecated. Please use 48 ${tf.train.MonitoredTrainingSession} instead. 49 50 The Supervisor is a small wrapper around a `Coordinator`, a `Saver`, 51 and a `SessionManager` that takes care of common needs of TensorFlow 52 training programs. 53 54 #### Use for a single program 55 56 ```python 57 with tf.Graph().as_default(): 58 ...add operations to the graph... 59 # Create a Supervisor that will checkpoint the model in '/tmp/mydir'. 60 sv = Supervisor(logdir='/tmp/mydir') 61 # Get a TensorFlow session managed by the supervisor. 62 with sv.managed_session(FLAGS.master) as sess: 63 # Use the session to train the graph. 64 while not sv.should_stop(): 65 sess.run(<my_train_op>) 66 ``` 67 68 Within the `with sv.managed_session()` block all variables in the graph have 69 been initialized. In addition, a few services have been started to 70 checkpoint the model and add summaries to the event log. 71 72 If the program crashes and is restarted, the managed session automatically 73 reinitialize variables from the most recent checkpoint. 74 75 The supervisor is notified of any exception raised by one of the services. 76 After an exception is raised, `should_stop()` returns `True`. In that case 77 the training loop should also stop. This is why the training loop has to 78 check for `sv.should_stop()`. 79 80 Exceptions that indicate that the training inputs have been exhausted, 81 `tf.errors.OutOfRangeError`, also cause `sv.should_stop()` to return `True` 82 but are not re-raised from the `with` block: they indicate a normal 83 termination. 84 85 #### Use for multiple replicas 86 87 To train with replicas you deploy the same program in a `Cluster`. 88 One of the tasks must be identified as the *chief*: the task that handles 89 initialization, checkpoints, summaries, and recovery. The other tasks 90 depend on the *chief* for these services. 91 92 The only change you have to do to the single program code is to indicate 93 if the program is running as the *chief*. 94 95 ```python 96 # Choose a task as the chief. This could be based on server_def.task_index, 97 # or job_def.name, or job_def.tasks. It's entirely up to the end user. 98 # But there can be only one *chief*. 99 is_chief = (server_def.task_index == 0) 100 server = tf.train.Server(server_def) 101 102 with tf.Graph().as_default(): 103 ...add operations to the graph... 104 # Create a Supervisor that uses log directory on a shared file system. 105 # Indicate if you are the 'chief' 106 sv = Supervisor(logdir='/shared_directory/...', is_chief=is_chief) 107 # Get a Session in a TensorFlow server on the cluster. 108 with sv.managed_session(server.target) as sess: 109 # Use the session to train the graph. 110 while not sv.should_stop(): 111 sess.run(<my_train_op>) 112 ``` 113 114 In the *chief* task, the `Supervisor` works exactly as in the first example 115 above. In the other tasks `sv.managed_session()` waits for the Model to have 116 been initialized before returning a session to the training code. The 117 non-chief tasks depend on the chief task for initializing the model. 118 119 If one of the tasks crashes and restarts, `managed_session()` 120 checks if the Model is initialized. If yes, it just creates a session and 121 returns it to the training code that proceeds normally. If the model needs 122 to be initialized, the chief task takes care of reinitializing it; the other 123 tasks just wait for the model to have been initialized. 124 125 NOTE: This modified program still works fine as a single program. 126 The single program marks itself as the chief. 127 128 #### What `master` string to use 129 130 Whether you are running on your machine or in the cluster you can use the 131 following values for the --master flag: 132 133 * Specifying `''` requests an in-process session that does not use RPC. 134 135 * Specifying `'local'` requests a session that uses the RPC-based 136 "Master interface" to run TensorFlow programs. See 137 @{tf.train.Server.create_local_server} for 138 details. 139 140 * Specifying `'grpc://hostname:port'` requests a session that uses 141 the RPC interface to a specific host, and also allows the in-process 142 master to access remote tensorflow workers. Often, it is 143 appropriate to pass `server.target` (for some `tf.train.Server` 144 named `server). 145 146 #### Advanced use 147 148 ##### Launching additional services 149 150 `managed_session()` launches the Checkpoint and Summary services (threads). 151 If you need more services to run you can simply launch them in the block 152 controlled by `managed_session()`. 153 154 Example: Start a thread to print losses. We want this thread to run 155 every 60 seconds, so we launch it with `sv.loop()`. 156 157 ```python 158 ... 159 sv = Supervisor(logdir='/tmp/mydir') 160 with sv.managed_session(FLAGS.master) as sess: 161 sv.loop(60, print_loss, (sess, )) 162 while not sv.should_stop(): 163 sess.run(my_train_op) 164 ``` 165 166 ##### Launching fewer services 167 168 `managed_session()` launches the "summary" and "checkpoint" threads which use 169 either the optionally `summary_op` and `saver` passed to the constructor, or 170 default ones created automatically by the supervisor. If you want to run 171 your own summary and checkpointing logic, disable these services by passing 172 `None` to the `summary_op` and `saver` parameters. 173 174 Example: Create summaries manually every 100 steps in the chief. 175 176 ```python 177 # Create a Supervisor with no automatic summaries. 178 sv = Supervisor(logdir='/tmp/mydir', is_chief=is_chief, summary_op=None) 179 # As summary_op was None, managed_session() does not start the 180 # summary thread. 181 with sv.managed_session(FLAGS.master) as sess: 182 for step in xrange(1000000): 183 if sv.should_stop(): 184 break 185 if is_chief and step % 100 == 0: 186 # Create the summary every 100 chief steps. 187 sv.summary_computed(sess, sess.run(my_summary_op)) 188 else: 189 # Train normally 190 sess.run(my_train_op) 191 ``` 192 193 ##### Custom model initialization 194 195 `managed_session()` only supports initializing the model by running an 196 `init_op` or restoring from the latest checkpoint. If you have special 197 initialization needs, see how to specify a `local_init_op` when creating the 198 supervisor. You can also use the `SessionManager` directly to create a 199 session and check if it could be initialized automatically. 200 """ 201 202 # Value to pass for the 'ready_op', 'init_op', 'summary_op', 'saver', 203 # and 'global_step' parameters of Supervisor.__init__() to indicate that 204 # the default behavior should be used. 205 USE_DEFAULT = 0 206 207 @deprecation.deprecated(None, 208 "Please switch to tf.train.MonitoredTrainingSession") 209 def __init__(self, 210 graph=None, 211 ready_op=USE_DEFAULT, 212 ready_for_local_init_op=USE_DEFAULT, 213 is_chief=True, 214 init_op=USE_DEFAULT, 215 init_feed_dict=None, 216 local_init_op=USE_DEFAULT, 217 logdir=None, 218 summary_op=USE_DEFAULT, 219 saver=USE_DEFAULT, 220 global_step=USE_DEFAULT, 221 save_summaries_secs=120, 222 save_model_secs=600, 223 recovery_wait_secs=30, 224 stop_grace_secs=120, 225 checkpoint_basename="model.ckpt", 226 session_manager=None, 227 summary_writer=USE_DEFAULT, 228 init_fn=None): 229 """Create a `Supervisor`. 230 231 Args: 232 graph: A `Graph`. The graph that the model will use. Defaults to the 233 default `Graph`. The supervisor may add operations to the graph before 234 creating a session, but the graph should not be modified by the caller 235 after passing it to the supervisor. 236 ready_op: 1-D string `Tensor`. This tensor is evaluated by supervisors in 237 `prepare_or_wait_for_session()` to check if the model is ready to use. 238 The model is considered ready if it returns an empty array. Defaults to 239 the tensor returned from `tf.report_uninitialized_variables()` If 240 `None`, the model is not checked for readiness. 241 ready_for_local_init_op: 1-D string `Tensor`. This tensor is evaluated by 242 supervisors in `prepare_or_wait_for_session()` to check if the model is 243 ready to run the local_init_op. 244 The model is considered ready if it returns an empty array. Defaults to 245 the tensor returned from 246 `tf.report_uninitialized_variables(tf.global_variables())`. If `None`, 247 the model is not checked for readiness before running local_init_op. 248 is_chief: If True, create a chief supervisor in charge of initializing 249 and restoring the model. If False, create a supervisor that relies 250 on a chief supervisor for inits and restore. 251 init_op: `Operation`. Used by chief supervisors to initialize the model 252 when it can not be recovered. Defaults to an `Operation` that 253 initializes all global variables. If `None`, no initialization is done 254 automatically unless you pass a value for `init_fn`, see below. 255 init_feed_dict: A dictionary that maps `Tensor` objects to feed values. 256 This feed dictionary will be used when `init_op` is evaluated. 257 local_init_op: `Operation`. Used by all supervisors to run initializations 258 that should run for every new supervisor instance. By default these 259 are table initializers and initializers for local variables. 260 If `None`, no further per supervisor-instance initialization is 261 done automatically. 262 logdir: A string. Optional path to a directory where to checkpoint the 263 model and log events for the visualizer. Used by chief supervisors. 264 The directory will be created if it does not exist. 265 summary_op: An `Operation` that returns a Summary for the event logs. 266 Used by chief supervisors if a `logdir` was specified. Defaults to the 267 operation returned from summary.merge_all(). If `None`, summaries are 268 not computed automatically. 269 saver: A Saver object. Used by chief supervisors if a `logdir` was 270 specified. Defaults to the saved returned by Saver(). 271 If `None`, the model is not saved automatically. 272 global_step: An integer Tensor of size 1 that counts steps. The value 273 from 'global_step' is used in summaries and checkpoint filenames. 274 Default to the op named 'global_step' in the graph if it exists, is of 275 rank 1, size 1, and of type tf.int32 or tf.int64. If `None` the global 276 step is not recorded in summaries and checkpoint files. Used by chief 277 supervisors if a `logdir` was specified. 278 save_summaries_secs: Number of seconds between the computation of 279 summaries for the event log. Defaults to 120 seconds. Pass 0 to 280 disable summaries. 281 save_model_secs: Number of seconds between the creation of model 282 checkpoints. Defaults to 600 seconds. Pass 0 to disable checkpoints. 283 recovery_wait_secs: Number of seconds between checks that the model 284 is ready. Used by supervisors when waiting for a chief supervisor 285 to initialize or restore the model. Defaults to 30 seconds. 286 stop_grace_secs: Grace period, in seconds, given to running threads to 287 stop when `stop()` is called. Defaults to 120 seconds. 288 checkpoint_basename: The basename for checkpoint saving. 289 session_manager: `SessionManager`, which manages Session creation and 290 recovery. If it is `None`, a default `SessionManager` will be created 291 with the set of arguments passed in for backwards compatibility. 292 summary_writer: `SummaryWriter` to use or `USE_DEFAULT`. Can be `None` 293 to indicate that no summaries should be written. 294 init_fn: Optional callable used to initialize the model. Called 295 after the optional `init_op` is called. The callable must accept one 296 argument, the session being initialized. 297 298 Returns: 299 A `Supervisor`. 300 301 Raises: 302 RuntimeError: If called with eager execution enabled. 303 304 @compatibility(eager) 305 `Supervisor`s are not supported when eager execution is enabled. 306 @end_compatibility 307 """ 308 if context.in_eager_mode(): 309 raise RuntimeError("Supervisors are compatible with eager execution.") 310 # Set default values of arguments. 311 if graph is None: 312 graph = ops.get_default_graph() 313 with graph.as_default(): 314 self._init_ready_op( 315 ready_op=ready_op, ready_for_local_init_op=ready_for_local_init_op) 316 self._init_init_op(init_op=init_op, init_feed_dict=init_feed_dict) 317 self._init_local_init_op(local_init_op=local_init_op) 318 self._init_saver(saver=saver) 319 self._init_summary_op(summary_op=summary_op) 320 self._init_global_step(global_step=global_step) 321 self._graph = graph 322 self._meta_graph_def = meta_graph.create_meta_graph_def( 323 graph_def=graph.as_graph_def(add_shapes=True), 324 saver_def=self._saver.saver_def if self._saver else None) 325 self._is_chief = is_chief 326 self._coord = coordinator.Coordinator() 327 self._recovery_wait_secs = recovery_wait_secs 328 self._stop_grace_secs = stop_grace_secs 329 self._init_fn = init_fn 330 331 # Set all attributes related to checkpointing and writing events to None. 332 # Afterwards, set them appropriately for chief supervisors, as these are 333 # the only supervisors that can write checkpoints and events. 334 self._logdir = None 335 self._save_summaries_secs = None 336 self._save_model_secs = None 337 self._save_path = None 338 self._summary_writer = None 339 340 if self._is_chief: 341 self._logdir = logdir 342 self._save_summaries_secs = save_summaries_secs 343 self._save_model_secs = save_model_secs 344 if self._logdir: 345 self._save_path = os.path.join(self._logdir, checkpoint_basename) 346 if summary_writer is Supervisor.USE_DEFAULT: 347 if self._logdir: 348 self._summary_writer = _summary.FileWriter(self._logdir) 349 else: 350 self._summary_writer = summary_writer 351 self._graph_added_to_summary = False 352 353 self._init_session_manager(session_manager=session_manager) 354 self._verify_setup() 355 # The graph is not allowed to change anymore. 356 graph.finalize() 357 358 def _init_session_manager(self, session_manager=None): 359 if session_manager is None: 360 self._session_manager = session_manager_mod.SessionManager( 361 local_init_op=self._local_init_op, 362 ready_op=self._ready_op, 363 ready_for_local_init_op=self._ready_for_local_init_op, 364 graph=self._graph, 365 recovery_wait_secs=self._recovery_wait_secs) 366 else: 367 self._session_manager = session_manager 368 369 def _get_first_op_from_collection(self, key): 370 """Returns the first `Operation` from a collection. 371 372 Args: 373 key: A string collection key. 374 375 Returns: 376 The first Op found in a collection, or `None` if the collection is empty. 377 """ 378 try: 379 op_list = ops.get_collection(key) 380 if len(op_list) > 1: 381 logging.info("Found %d %s operations. Returning the first one.", 382 len(op_list), key) 383 if op_list: 384 return op_list[0] 385 except LookupError: 386 pass 387 388 return None 389 390 def _init_ready_op(self, 391 ready_op=USE_DEFAULT, 392 ready_for_local_init_op=USE_DEFAULT): 393 """Initializes ready_op. 394 395 Args: 396 ready_op: `Tensor` to check if the model is initialized. 397 If it's set to USE_DEFAULT, creates an op that checks all 398 the variables are initialized. 399 ready_for_local_init_op: `Tensor` to check if the model is ready to run 400 local_init_op. 401 If it's set to USE_DEFAULT, creates an op that checks all 402 the global variables are initialized. 403 """ 404 if ready_op is Supervisor.USE_DEFAULT: 405 ready_op = self._get_first_op_from_collection(ops.GraphKeys.READY_OP) 406 if ready_op is None: 407 ready_op = variables.report_uninitialized_variables() 408 ops.add_to_collection(ops.GraphKeys.READY_OP, ready_op) 409 self._ready_op = ready_op 410 411 # ready_for_local_init_op defaults to None for backward compatibility 412 if ready_for_local_init_op is Supervisor.USE_DEFAULT: 413 ready_for_local_init_op = self._get_first_op_from_collection( 414 ops.GraphKeys.READY_FOR_LOCAL_INIT_OP) 415 self._ready_for_local_init_op = ready_for_local_init_op 416 417 def _init_init_op(self, init_op=USE_DEFAULT, init_feed_dict=None): 418 """Initializes init_op. 419 420 Args: 421 init_op: `Operation` to initialize the variables. If set to USE_DEFAULT, 422 create an op that initializes all variables and tables. 423 init_feed_dict: A dictionary that maps `Tensor` objects to feed values. 424 This feed dictionary will be used when `init_op` is evaluated. 425 """ 426 if init_op is Supervisor.USE_DEFAULT: 427 init_op = self._get_first_op_from_collection(ops.GraphKeys.INIT_OP) 428 if init_op is None: 429 init_op = variables.global_variables_initializer() 430 ops.add_to_collection(ops.GraphKeys.INIT_OP, init_op) 431 self._init_op = init_op 432 self._init_feed_dict = init_feed_dict 433 434 def _init_local_init_op(self, local_init_op=USE_DEFAULT): 435 """Initializes local_init_op. 436 437 Args: 438 local_init_op: `Operation` run for every new supervisor instance. If set 439 to USE_DEFAULT, use the first op from the GraphKeys.LOCAL_INIT_OP 440 collection. If the collection is empty, create an op that initializes 441 all local variables and all tables. 442 """ 443 if local_init_op is Supervisor.USE_DEFAULT: 444 local_init_op = self._get_first_op_from_collection( 445 ops.GraphKeys.LOCAL_INIT_OP) 446 if local_init_op is None: 447 op_list = [ 448 variables.local_variables_initializer(), 449 lookup_ops.tables_initializer() 450 ] 451 if op_list: 452 local_init_op = control_flow_ops.group(*op_list) 453 ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op) 454 self._local_init_op = local_init_op 455 456 def _init_saver(self, saver=USE_DEFAULT): 457 """Initializes saver. 458 459 Args: 460 saver: A `Saver` object. If set to USE_DEFAULT, create one that 461 saves all the variables. 462 """ 463 if saver is Supervisor.USE_DEFAULT: 464 saver = self._get_first_op_from_collection(ops.GraphKeys.SAVERS) 465 if saver is None and variables.global_variables(): 466 saver = saver_mod.Saver() 467 ops.add_to_collection(ops.GraphKeys.SAVERS, saver) 468 self._saver = saver 469 470 def _init_summary_op(self, summary_op=USE_DEFAULT): 471 """Initializes summary_op. 472 473 Args: 474 summary_op: An Operation that returns a Summary for the event logs. 475 If set to USE_DEFAULT, create an op that merges all the summaries. 476 """ 477 if summary_op is Supervisor.USE_DEFAULT: 478 summary_op = self._get_first_op_from_collection(ops.GraphKeys.SUMMARY_OP) 479 if summary_op is None: 480 summary_op = _summary.merge_all() 481 if summary_op is not None: 482 ops.add_to_collection(ops.GraphKeys.SUMMARY_OP, summary_op) 483 self._summary_op = summary_op 484 485 def _init_global_step(self, global_step=USE_DEFAULT): 486 """Initializes global_step. 487 488 Args: 489 global_step: An integer Tensor of size 1 that counts steps. If 490 set to USE_DEFAULT, creates global_step tensor. 491 """ 492 if global_step is Supervisor.USE_DEFAULT: 493 global_step = self._get_first_op_from_collection( 494 ops.GraphKeys.GLOBAL_STEP) 495 if global_step is None: 496 global_step = self._default_global_step_tensor() 497 if global_step is not None: 498 ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, global_step) 499 self._global_step = global_step 500 501 @property 502 def is_chief(self): 503 """Return True if this is a chief supervisor. 504 505 Returns: 506 A bool. 507 """ 508 return self._is_chief 509 510 @property 511 def session_manager(self): 512 """Return the SessionManager used by the Supervisor. 513 514 Returns: 515 A SessionManager object. 516 """ 517 return self._session_manager 518 519 @property 520 def coord(self): 521 """Return the Coordinator used by the Supervisor. 522 523 The Coordinator can be useful if you want to run multiple threads 524 during your training. 525 526 Returns: 527 A Coordinator object. 528 """ 529 return self._coord 530 531 @property 532 def init_op(self): 533 """Return the Init Op used by the supervisor. 534 535 Returns: 536 An Op or `None`. 537 """ 538 return self._init_op 539 540 @property 541 def init_feed_dict(self): 542 """Return the feed dictionary used when evaluating the `init_op`. 543 544 Returns: 545 A feed dictionary or `None`. 546 """ 547 return self._init_feed_dict 548 549 @property 550 def ready_op(self): 551 """Return the Ready Op used by the supervisor. 552 553 Returns: 554 An Op or `None`. 555 """ 556 return self._ready_op 557 558 @property 559 def ready_for_local_init_op(self): 560 return self._ready_for_local_init_op 561 562 @property 563 def summary_writer(self): 564 """Return the SummaryWriter used by the chief supervisor. 565 566 Returns: 567 A SummaryWriter. 568 """ 569 return self._summary_writer 570 571 @property 572 def summary_op(self): 573 """Return the Summary Tensor used by the chief supervisor. 574 575 Returns: 576 A string Tensor for the summary or `None`. 577 """ 578 return self._summary_op 579 580 @property 581 def save_summaries_secs(self): 582 """Return the delay between summary computations. 583 584 Returns: 585 A timestamp. 586 """ 587 return self._save_summaries_secs 588 589 @property 590 def global_step(self): 591 """Return the global_step Tensor used by the supervisor. 592 593 Returns: 594 An integer Tensor for the global_step. 595 """ 596 return self._global_step 597 598 @property 599 def saver(self): 600 """Return the Saver used by the supervisor. 601 602 Returns: 603 A Saver object. 604 """ 605 return self._saver 606 607 @property 608 def save_model_secs(self): 609 """Return the delay between checkpoints. 610 611 Returns: 612 A timestamp. 613 """ 614 return self._save_model_secs 615 616 @property 617 def save_path(self): 618 """Return the save path used by the supervisor. 619 620 Returns: 621 A string. 622 """ 623 return self._save_path 624 625 def _write_graph(self): 626 """Writes graph_def to `logdir` and adds it to summary if applicable.""" 627 assert self._is_chief 628 if self._logdir: 629 training_util.write_graph(self._graph.as_graph_def(add_shapes=True), 630 self._logdir, "graph.pbtxt") 631 if self._summary_writer and not self._graph_added_to_summary: 632 self._summary_writer.add_graph(self._graph) 633 self._summary_writer.add_meta_graph(self._meta_graph_def) 634 self._graph_added_to_summary = True 635 636 def start_standard_services(self, sess): 637 """Start the standard services for 'sess'. 638 639 This starts services in the background. The services started depend 640 on the parameters to the constructor and may include: 641 642 - A Summary thread computing summaries every save_summaries_secs. 643 - A Checkpoint thread saving the model every save_model_secs. 644 - A StepCounter thread measure step time. 645 646 Args: 647 sess: A Session. 648 649 Returns: 650 A list of threads that are running the standard services. You can use 651 the Supervisor's Coordinator to join these threads with: 652 sv.coord.Join(<list of threads>) 653 654 Raises: 655 RuntimeError: If called with a non-chief Supervisor. 656 ValueError: If not `logdir` was passed to the constructor as the 657 services need a log directory. 658 """ 659 if not self._is_chief: 660 raise RuntimeError("Only chief supervisor can start standard services. " 661 "Because only chief supervisors can write events.") 662 663 if not self._logdir: 664 logging.warning("Standard services need a 'logdir' " 665 "passed to the SessionManager") 666 return 667 668 if self._global_step is not None and self._summary_writer: 669 # Only add the session log if we keep track of global step. 670 # TensorBoard cannot use START message for purging expired events 671 # if there is no step value. 672 current_step = training_util.global_step(sess, self._global_step) 673 self._summary_writer.add_session_log( 674 SessionLog(status=SessionLog.START), 675 current_step) 676 677 threads = [] 678 if self._save_summaries_secs and self._summary_writer: 679 if self._summary_op is not None: 680 threads.append(SVSummaryThread(self, sess)) 681 if self._global_step is not None: 682 threads.append(SVStepCounterThread(self, sess)) 683 if self.saver and self._save_model_secs: 684 threads.append(SVTimerCheckpointThread(self, sess)) 685 for t in threads: 686 t.start() 687 return threads 688 689 def prepare_or_wait_for_session(self, master="", config=None, 690 wait_for_checkpoint=False, 691 max_wait_secs=7200, 692 start_standard_services=True): 693 """Make sure the model is ready to be used. 694 695 Create a session on 'master', recovering or initializing the model as 696 needed, or wait for a session to be ready. If running as the chief 697 and `start_standard_service` is set to True, also call the session 698 manager to start the standard services. 699 700 Args: 701 master: name of the TensorFlow master to use. See the `tf.Session` 702 constructor for how this is interpreted. 703 config: Optional ConfigProto proto used to configure the session, 704 which is passed as-is to create the session. 705 wait_for_checkpoint: Whether we should wait for the availability of a 706 checkpoint before creating Session. Defaults to False. 707 max_wait_secs: Maximum time to wait for the session to become available. 708 start_standard_services: Whether to start the standard services and the 709 queue runners. 710 711 Returns: 712 A Session object that can be used to drive the model. 713 """ 714 # For users who recreate the session with prepare_or_wait_for_session(), we 715 # need to clear the coordinator's stop_event so that threads managed by the 716 # coordinator can run. 717 self._coord.clear_stop() 718 if self._summary_writer: 719 self._summary_writer.reopen() 720 721 if self._is_chief: 722 sess = self._session_manager.prepare_session( 723 master, init_op=self.init_op, saver=self.saver, 724 checkpoint_dir=self._logdir, wait_for_checkpoint=wait_for_checkpoint, 725 max_wait_secs=max_wait_secs, config=config, 726 init_feed_dict=self._init_feed_dict, init_fn=self._init_fn) 727 self._write_graph() 728 if start_standard_services: 729 logging.info("Starting standard services.") 730 self.start_standard_services(sess) 731 else: 732 sess = self._session_manager.wait_for_session(master, 733 config=config, 734 max_wait_secs=max_wait_secs) 735 if start_standard_services: 736 logging.info("Starting queue runners.") 737 self.start_queue_runners(sess) 738 return sess 739 740 def start_queue_runners(self, sess, queue_runners=None): 741 """Start threads for `QueueRunners`. 742 743 Note that the queue runners collected in the graph key `QUEUE_RUNNERS` 744 are already started automatically when you create a session with the 745 supervisor, so unless you have non-collected queue runners to start 746 you do not need to call this explicitly. 747 748 Args: 749 sess: A `Session`. 750 queue_runners: A list of `QueueRunners`. If not specified, we'll use the 751 list of queue runners gathered in the graph under the key 752 `GraphKeys.QUEUE_RUNNERS`. 753 754 Returns: 755 The list of threads started for the `QueueRunners`. 756 757 Raises: 758 RuntimeError: If called with eager execution enabled. 759 760 @compatibility(eager) 761 Queues are not compatible with eager execution. To ingest data when eager 762 execution is enabled, use the `tf.data` API. 763 @end_compatibility 764 """ 765 if context.in_eager_mode(): 766 raise RuntimeError("Queues are not compatible with eager execution.") 767 if queue_runners is None: 768 queue_runners = self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS) 769 threads = [] 770 for qr in queue_runners: 771 threads.extend(qr.create_threads(sess, coord=self._coord, daemon=True, 772 start=True)) 773 return threads 774 775 def loop(self, timer_interval_secs, target, args=None, kwargs=None): 776 """Start a LooperThread that calls a function periodically. 777 778 If `timer_interval_secs` is None the thread calls `target(*args, **kwargs)` 779 repeatedly. Otherwise it calls it every `timer_interval_secs` 780 seconds. The thread terminates when a stop is requested. 781 782 The started thread is added to the list of threads managed by the supervisor 783 so it does not need to be passed to the `stop()` method. 784 785 Args: 786 timer_interval_secs: Number. Time boundaries at which to call `target`. 787 target: A callable object. 788 args: Optional arguments to pass to `target` when calling it. 789 kwargs: Optional keyword arguments to pass to `target` when calling it. 790 791 Returns: 792 The started thread. 793 """ 794 looper = coordinator.LooperThread(self._coord, timer_interval_secs, 795 target=target, args=args, kwargs=kwargs) 796 looper.start() 797 return looper 798 799 def stop(self, 800 threads=None, 801 close_summary_writer=True, 802 ignore_live_threads=False): 803 """Stop the services and the coordinator. 804 805 This does not close the session. 806 807 Args: 808 threads: Optional list of threads to join with the coordinator. If 809 `None`, defaults to the threads running the standard services, the 810 threads started for `QueueRunners`, and the threads started by the 811 `loop()` method. To wait on additional threads, pass the 812 list in this parameter. 813 close_summary_writer: Whether to close the `summary_writer`. Defaults to 814 `True` if the summary writer was created by the supervisor, `False` 815 otherwise. 816 ignore_live_threads: If `True` ignores threads that remain running after 817 a grace period when joining threads via the coordinator, instead of 818 raising a RuntimeError. 819 """ 820 self._coord.request_stop() 821 try: 822 # coord.join() re-raises the first reported exception; the "finally" 823 # block ensures that we clean up whether or not an exception was 824 # reported. 825 self._coord.join( 826 threads, 827 stop_grace_period_secs=self._stop_grace_secs, 828 ignore_live_threads=ignore_live_threads) 829 finally: 830 # Close the writer last, in case one of the running threads was using it. 831 if close_summary_writer and self._summary_writer: 832 # Stop messages are not logged with event.step, 833 # since the session may have already terminated. 834 self._summary_writer.add_session_log(SessionLog(status=SessionLog.STOP)) 835 self._summary_writer.close() 836 self._graph_added_to_summary = False 837 838 def request_stop(self, ex=None): 839 """Request that the coordinator stop the threads. 840 841 See `Coordinator.request_stop()`. 842 843 Args: 844 ex: Optional `Exception`, or Python `exc_info` tuple as returned by 845 `sys.exc_info()`. If this is the first call to `request_stop()` the 846 corresponding exception is recorded and re-raised from `join()`. 847 """ 848 self._coord.request_stop(ex=ex) 849 850 def should_stop(self): 851 """Check if the coordinator was told to stop. 852 853 See `Coordinator.should_stop()`. 854 855 Returns: 856 True if the coordinator was told to stop, False otherwise. 857 """ 858 return self._coord.should_stop() 859 860 def stop_on_exception(self): 861 """Context handler to stop the supervisor when an exception is raised. 862 863 See `Coordinator.stop_on_exception()`. 864 865 Returns: 866 A context handler. 867 """ 868 return self._coord.stop_on_exception() 869 870 def wait_for_stop(self): 871 """Block waiting for the coordinator to stop.""" 872 self._coord.wait_for_stop() 873 874 def summary_computed(self, sess, summary, global_step=None): 875 """Indicate that a summary was computed. 876 877 Args: 878 sess: A `Session` object. 879 summary: A Summary proto, or a string holding a serialized summary proto. 880 global_step: Int. global step this summary is associated with. If `None`, 881 it will try to fetch the current step. 882 883 Raises: 884 TypeError: if 'summary' is not a Summary proto or a string. 885 RuntimeError: if the Supervisor was created without a `logdir`. 886 """ 887 if not self._summary_writer: 888 raise RuntimeError("Writing a summary requires a summary writer.") 889 if global_step is None and self.global_step is not None: 890 global_step = training_util.global_step(sess, self.global_step) 891 self._summary_writer.add_summary(summary, global_step) 892 893 def _default_global_step_tensor(self): 894 """Returns the global_step from the default graph. 895 896 Returns: 897 The global step `Tensor` or `None`. 898 """ 899 try: 900 gs = ops.get_default_graph().get_tensor_by_name("global_step:0") 901 if gs.dtype.base_dtype in [dtypes.int32, dtypes.int64]: 902 return gs 903 else: 904 logging.warning("Found 'global_step' is not an int type: %s", gs.dtype) 905 return None 906 except KeyError: 907 return None 908 909 def _verify_setup(self): 910 """Check that all is good. 911 912 Raises: 913 ValueError: If something is not good. 914 """ 915 # Not running as chief means that replicas are used. 916 # In that case all Variables must have their device set. 917 if not self._is_chief: 918 for op in self._graph.get_operations(): 919 if op.type in ["Variable", "VariableV2"] and not op.device: 920 raise ValueError("When using replicas, all Variables must have " 921 "their device set: %s" % op) 922 923 # pylint: disable=g-doc-return-or-yield,broad-except 924 @contextlib.contextmanager 925 def managed_session(self, master="", config=None, 926 start_standard_services=True, 927 close_summary_writer=True): 928 """Returns a context manager for a managed session. 929 930 This context manager creates and automatically recovers a session. It 931 optionally starts the standard services that handle checkpoints and 932 summaries. It monitors exceptions raised from the `with` block or from the 933 services and stops the supervisor as needed. 934 935 The context manager is typically used as follows: 936 937 ```python 938 def train(): 939 sv = tf.train.Supervisor(...) 940 with sv.managed_session(<master>) as sess: 941 for step in xrange(..): 942 if sv.should_stop(): 943 break 944 sess.run(<my training op>) 945 ...do other things needed at each training step... 946 ``` 947 948 An exception raised from the `with` block or one of the service threads is 949 raised again when the block exits. This is done after stopping all threads 950 and closing the session. For example, an `AbortedError` exception, raised 951 in case of preemption of one of the workers in a distributed model, is 952 raised again when the block exits. 953 954 If you want to retry the training loop in case of preemption you can do it 955 as follows: 956 957 ```python 958 def main(...): 959 while True 960 try: 961 train() 962 except tf.errors.Aborted: 963 pass 964 ``` 965 966 As a special case, exceptions used for control flow, such as 967 `OutOfRangeError` which reports that input queues are exhausted, are not 968 raised again from the `with` block: they indicate a clean termination of 969 the training loop and are considered normal termination. 970 971 Args: 972 master: name of the TensorFlow master to use. See the `tf.Session` 973 constructor for how this is interpreted. 974 config: Optional `ConfigProto` proto used to configure the session. 975 Passed as-is to create the session. 976 start_standard_services: Whether to start the standard services, 977 such as checkpoint, summary and step counter. 978 close_summary_writer: Whether to close the summary writer when 979 closing the session. Defaults to True. 980 981 Returns: 982 A context manager that yields a `Session` restored from the latest 983 checkpoint or initialized from scratch if not checkpoint exists. The 984 session is closed when the `with` block exits. 985 """ 986 try: 987 sess = self.prepare_or_wait_for_session( 988 master=master, config=config, 989 start_standard_services=start_standard_services) 990 yield sess 991 except Exception as e: 992 self.request_stop(e) 993 finally: 994 try: 995 # Request all the threads to stop and wait for them to do so. Any 996 # exception raised by the threads is raised again from stop(). 997 # Passing stop_grace_period_secs is for blocked enqueue/dequeue 998 # threads which are not checking for `should_stop()`. They 999 # will be stopped when we close the session further down. 1000 self.stop(close_summary_writer=close_summary_writer) 1001 finally: 1002 # Close the session to finish up all pending calls. We do not care 1003 # about exceptions raised when closing. This takes care of 1004 # blocked enqueue/dequeue calls. 1005 try: 1006 sess.close() 1007 except Exception: 1008 # Silently ignore exceptions raised by close(). 1009 pass 1010 # pylint: enable=g-doc-return-or-yield,broad-except 1011 1012 1013 class SVSummaryThread(coordinator.LooperThread): 1014 """A thread to save summaries on a timer.""" 1015 1016 def __init__(self, sv, sess): 1017 """Create a SVSummaryThread. 1018 1019 Args: 1020 sv: A `Supervisor`. 1021 sess: A `Session`. 1022 """ 1023 super(SVSummaryThread, self).__init__(sv.coord, sv.save_summaries_secs) 1024 self._sv = sv 1025 self._sess = sess 1026 1027 def run_loop(self): 1028 if self._sv.global_step is not None: 1029 summary_strs, global_step = self._sess.run([self._sv.summary_op, 1030 self._sv.global_step]) 1031 else: 1032 summary_strs = self._sess.run(self._sv.summary_op) 1033 global_step = None 1034 if self._sv.summary_writer: 1035 logging.info("Recording summary at step %s.", global_step) 1036 self._sv.summary_writer.add_summary(summary_strs, global_step) 1037 1038 1039 class SVStepCounterThread(coordinator.LooperThread): 1040 """Threads to count steps and measure their duration.""" 1041 1042 def __init__(self, sv, sess, step_counter=None): 1043 """Create a `SVStepCounterThread`. 1044 1045 Args: 1046 sv: A `Supervisor`. 1047 sess: A `Session`. 1048 step_counter: A `Tensor` holding the step counter. By defaults, it uses 1049 sv.global_step. 1050 """ 1051 super(SVStepCounterThread, self).__init__(sv.coord, sv.save_summaries_secs) 1052 self._sv = sv 1053 self._sess = sess 1054 self._last_time = 0.0 1055 self._last_step = 0 1056 step_counter = sv.global_step if step_counter is None else step_counter 1057 self._step_counter = step_counter 1058 self._summary_tag = "%s/sec" % self._step_counter.op.name 1059 1060 def start_loop(self): 1061 self._last_time = time.time() 1062 self._last_step = training_util.global_step( 1063 self._sess, self._step_counter) 1064 1065 def run_loop(self): 1066 # Count the steps. 1067 current_step = training_util.global_step(self._sess, self._step_counter) 1068 added_steps = current_step - self._last_step 1069 self._last_step = current_step 1070 # Measure the elapsed time. 1071 current_time = time.time() 1072 elapsed_time = current_time - self._last_time 1073 self._last_time = current_time 1074 # Reports the number of steps done per second 1075 if elapsed_time > 0.: 1076 steps_per_sec = added_steps / elapsed_time 1077 else: 1078 steps_per_sec = float("inf") 1079 summary = Summary(value=[Summary.Value(tag=self._summary_tag, 1080 simple_value=steps_per_sec)]) 1081 if self._sv.summary_writer: 1082 self._sv.summary_writer.add_summary(summary, current_step) 1083 logging.log_first_n(logging.INFO, "%s: %g", 10, 1084 self._summary_tag, steps_per_sec) 1085 1086 1087 class SVTimerCheckpointThread(coordinator.LooperThread): 1088 """A thread to checkpoint on a timer.""" 1089 1090 def __init__(self, sv, sess): 1091 """Create a `SVTimerCheckpointThread`. 1092 1093 Args: 1094 sv: A `Supervisor`. 1095 sess: A `Session`. 1096 """ 1097 super(SVTimerCheckpointThread, self).__init__(sv.coord, sv.save_model_secs) 1098 self._sv = sv 1099 self._sess = sess 1100 1101 def run_loop(self): 1102 logging.info("Saving checkpoint to path %s", self._sv.save_path) 1103 self._sv.saver.save(self._sess, self._sv.save_path, 1104 global_step=self._sv.global_step) 1105 if self._sv.summary_writer and self._sv.global_step is not None: 1106 current_step = training_util.global_step(self._sess, self._sv.global_step) 1107 self._sv.summary_writer.add_session_log( 1108 SessionLog(status=SessionLog.CHECKPOINT, 1109 checkpoint_path=self._sv.save_path), 1110 current_step) 1111 1112 1113 # TODO(sherrym): All non-PEP8 compliant names will be deprecated shortly. 1114 setattr(Supervisor, "PrepareSession", Supervisor.prepare_or_wait_for_session) 1115 setattr(Supervisor, "StartQueueRunners", Supervisor.start_queue_runners) 1116 setattr(Supervisor, "StartStandardServices", Supervisor.start_standard_services) 1117 setattr(Supervisor, "Stop", Supervisor.stop) 1118 setattr(Supervisor, "RequestStop", Supervisor.request_stop) 1119 setattr(Supervisor, "Loop", Supervisor.loop) 1120 setattr(Supervisor, "ShouldStop", Supervisor.should_stop) 1121 setattr(Supervisor, "StopOnException", Supervisor.stop_on_exception) 1122 setattr(Supervisor, "WaitForStop", Supervisor.wait_for_stop) 1123 setattr(Supervisor, "SummaryComputed", Supervisor.summary_computed) 1124