# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training helper that checkpoints models and computes summaries."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import os
import time

from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary as _summary
from tensorflow.python.training import coordinator
from tensorflow.python.training import saver as saver_mod
from tensorflow.python.training import session_manager as session_manager_mod
from tensorflow.python.training import training_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export("train.Supervisor")
class Supervisor(object):
  """A training helper that checkpoints models and computes summaries.

  This class is deprecated. Please use
  @{tf.train.MonitoredTrainingSession} instead.

  The Supervisor is a small wrapper around a `Coordinator`, a `Saver`,
  and a `SessionManager` that takes care of common needs of TensorFlow
  training programs.

  #### Use for a single program

  ```python
  with tf.Graph().as_default():
    ...add operations to the graph...
    # Create a Supervisor that will checkpoint the model in '/tmp/mydir'.
    sv = Supervisor(logdir='/tmp/mydir')
    # Get a TensorFlow session managed by the supervisor.
    with sv.managed_session(FLAGS.master) as sess:
      # Use the session to train the graph.
      while not sv.should_stop():
        sess.run(<my_train_op>)
  ```

  Within the `with sv.managed_session()` block all variables in the graph have
  been initialized.  In addition, a few services have been started to
  checkpoint the model and add summaries to the event log.

  If the program crashes and is restarted, the managed session automatically
  reinitializes variables from the most recent checkpoint.

  The supervisor is notified of any exception raised by one of the services.
  After an exception is raised, `should_stop()` returns `True`.  In that case
  the training loop should also stop.  This is why the training loop has to
  check for `sv.should_stop()`.

  Exceptions that indicate that the training inputs have been exhausted, such
  as `tf.errors.OutOfRangeError`, also cause `sv.should_stop()` to return
  `True` but are not re-raised from the `with` block: they indicate a normal
  termination.

  #### Use for multiple replicas

  To train with replicas you deploy the same program in a `Cluster`.
  One of the tasks must be identified as the *chief*: the task that handles
  initialization, checkpoints, summaries, and recovery.  The other tasks
  depend on the *chief* for these services.

  The only change you have to make to the single program code is to indicate
  if the program is running as the *chief*.

  ```python
  # Choose a task as the chief. This could be based on server_def.task_index,
  # or job_def.name, or job_def.tasks. It's entirely up to the end user.
  # But there can be only one *chief*.
  is_chief = (server_def.task_index == 0)
  server = tf.train.Server(server_def)

  with tf.Graph().as_default():
    ...add operations to the graph...
    # Create a Supervisor that uses log directory on a shared file system.
    # Indicate if you are the 'chief'.
    sv = Supervisor(logdir='/shared_directory/...', is_chief=is_chief)
    # Get a Session in a TensorFlow server on the cluster.
    with sv.managed_session(server.target) as sess:
      # Use the session to train the graph.
      while not sv.should_stop():
        sess.run(<my_train_op>)
  ```

  In the *chief* task, the `Supervisor` works exactly as in the first example
  above.  In the other tasks `sv.managed_session()` waits for the model to
  have been initialized before returning a session to the training code.  The
  non-chief tasks depend on the chief task for initializing the model.

  If one of the tasks crashes and restarts, `managed_session()`
  checks if the model is initialized.  If yes, it just creates a session and
  returns it to the training code that proceeds normally.  If the model needs
  to be initialized, the chief task takes care of reinitializing it; the other
  tasks just wait for the model to have been initialized.

  NOTE: This modified program still works fine as a single program.
  The single program marks itself as the chief.

  #### What `master` string to use

  Whether you are running on your machine or in the cluster you can use the
  following values for the --master flag:

  * Specifying `''` requests an in-process session that does not use RPC.

  * Specifying `'local'` requests a session that uses the RPC-based
    "Master interface" to run TensorFlow programs. See
    @{tf.train.Server.create_local_server} for
    details.

  * Specifying `'grpc://hostname:port'` requests a session that uses
    the RPC interface to a specific host, and also allows the in-process
    master to access remote tensorflow workers. Often, it is
    appropriate to pass `server.target` (for some `tf.train.Server`
    named `server`).

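  For example (a sketch; the `grpc` host and port are placeholders):

  ```python
  sv = Supervisor(logdir='/tmp/mydir')
  # In-process session, no RPC:
  with sv.managed_session('') as sess:
    ...
  # RPC-based in-process master:
  with sv.managed_session('local') as sess:
    ...
  # Remote master; passing `server.target` for some `tf.train.Server`
  # named `server` is the common form:
  with sv.managed_session('grpc://localhost:2222') as sess:
    ...
  ```
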
  #### Advanced use

  ##### Launching additional services

  `managed_session()` launches the Checkpoint and Summary services (threads).
  If you need more services to run you can simply launch them in the block
  controlled by `managed_session()`.

  Example: Start a thread to print losses.  We want this thread to run
  every 60 seconds, so we launch it with `sv.loop()`.

  ```python
  ...
  sv = Supervisor(logdir='/tmp/mydir')
  with sv.managed_session(FLAGS.master) as sess:
    sv.loop(60, print_loss, (sess, ))
    while not sv.should_stop():
      sess.run(my_train_op)
  ```

  ##### Launching fewer services

  `managed_session()` launches the "summary" and "checkpoint" threads which
  use either the `summary_op` and `saver` optionally passed to the
  constructor, or default ones created automatically by the supervisor.  If
  you want to run your own summary and checkpointing logic, disable these
  services by passing `None` to the `summary_op` and `saver` parameters.

  Example: Create summaries manually every 100 steps in the chief.

  ```python
  # Create a Supervisor with no automatic summaries.
  sv = Supervisor(logdir='/tmp/mydir', is_chief=is_chief, summary_op=None)
  # As summary_op was None, managed_session() does not start the
  # summary thread.
  with sv.managed_session(FLAGS.master) as sess:
    for step in xrange(1000000):
      if sv.should_stop():
        break
      if is_chief and step % 100 == 0:
        # Create the summary every 100 chief steps.
        sv.summary_computed(sess, sess.run(my_summary_op))
      else:
        # Train normally
        sess.run(my_train_op)
  ```

  ##### Custom model initialization

  `managed_session()` only supports initializing the model by running an
  `init_op` or restoring from the latest checkpoint.  If you have special
  initialization needs, see how to specify a `local_init_op` when creating the
  supervisor.  You can also use the `SessionManager` directly to create a
  session and check if it could be initialized automatically.
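
  For example, a minimal sketch of the `SessionManager` route (assuming
  `init_op` and `saver` were created while building the graph; the
  checkpoint directory is a placeholder):

  ```python
  sm = tf.train.SessionManager()
  sess = sm.prepare_session(FLAGS.master, init_op=init_op, saver=saver,
                            checkpoint_dir='/tmp/mydir')
  ```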
  """

  # Value to pass for the 'ready_op', 'init_op', 'summary_op', 'saver',
  # and 'global_step' parameters of Supervisor.__init__() to indicate that
  # the default behavior should be used.
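  # For example, `Supervisor(logdir='/tmp/mydir', saver=None)` disables the
  # automatic checkpoint service while keeping the default summary behavior.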
  USE_DEFAULT = 0

  @deprecation.deprecated(None,
                          "Please switch to tf.train.MonitoredTrainingSession")
  def __init__(self,
               graph=None,
               ready_op=USE_DEFAULT,
               ready_for_local_init_op=USE_DEFAULT,
               is_chief=True,
               init_op=USE_DEFAULT,
               init_feed_dict=None,
               local_init_op=USE_DEFAULT,
               logdir=None,
               summary_op=USE_DEFAULT,
               saver=USE_DEFAULT,
               global_step=USE_DEFAULT,
               save_summaries_secs=120,
               save_model_secs=600,
               recovery_wait_secs=30,
               stop_grace_secs=120,
               checkpoint_basename="model.ckpt",
               session_manager=None,
               summary_writer=USE_DEFAULT,
               init_fn=None):
    """Create a `Supervisor`.

    Args:
      graph: A `Graph`.  The graph that the model will use.  Defaults to the
        default `Graph`.  The supervisor may add operations to the graph before
        creating a session, but the graph should not be modified by the caller
        after passing it to the supervisor.
      ready_op: 1-D string `Tensor`.  This tensor is evaluated by supervisors
        in `prepare_or_wait_for_session()` to check if the model is ready to
        use.  The model is considered ready if it returns an empty array.
        Defaults to the tensor returned from
        `tf.report_uninitialized_variables()`.  If `None`, the model is not
        checked for readiness.
      ready_for_local_init_op: 1-D string `Tensor`.  This tensor is evaluated
        by supervisors in `prepare_or_wait_for_session()` to check if the
        model is ready to run the local_init_op.
        The model is considered ready if it returns an empty array.  Defaults
        to the tensor returned from
        `tf.report_uninitialized_variables(tf.global_variables())`. If `None`,
        the model is not checked for readiness before running local_init_op.
      is_chief: If True, create a chief supervisor in charge of initializing
        and restoring the model.  If False, create a supervisor that relies
        on a chief supervisor for inits and restore.
      init_op: `Operation`.  Used by chief supervisors to initialize the model
        when it can not be recovered.  Defaults to an `Operation` that
        initializes all global variables.  If `None`, no initialization is done
        automatically unless you pass a value for `init_fn`, see below.
      init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
        This feed dictionary will be used when `init_op` is evaluated.
      local_init_op: `Operation`. Used by all supervisors to run initializations
        that should run for every new supervisor instance. By default these
        are table initializers and initializers for local variables.
        If `None`, no further per supervisor-instance initialization is
        done automatically.
      logdir: A string.  Optional path to a directory where to checkpoint the
        model and log events for the visualizer.  Used by chief supervisors.
        The directory will be created if it does not exist.
      summary_op: An `Operation` that returns a Summary for the event logs.
        Used by chief supervisors if a `logdir` was specified.  Defaults to the
        operation returned from summary.merge_all().  If `None`, summaries are
        not computed automatically.
      saver: A Saver object.  Used by chief supervisors if a `logdir` was
        specified.  Defaults to the saver returned by Saver().
        If `None`, the model is not saved automatically.
      global_step: An integer Tensor of size 1 that counts steps.  The value
        from 'global_step' is used in summaries and checkpoint filenames.
        Defaults to the tensor named 'global_step' in the graph if it exists,
        is of rank 1, size 1, and of type tf.int32 or tf.int64.  If `None` the
        global step is not recorded in summaries and checkpoint files.  Used
        by chief supervisors if a `logdir` was specified.
      save_summaries_secs: Number of seconds between the computation of
        summaries for the event log.  Defaults to 120 seconds.  Pass 0 to
        disable summaries.
      save_model_secs: Number of seconds between the creation of model
        checkpoints.  Defaults to 600 seconds.  Pass 0 to disable checkpoints.
      recovery_wait_secs: Number of seconds between checks that the model
        is ready.  Used by supervisors when waiting for a chief supervisor
        to initialize or restore the model.  Defaults to 30 seconds.
      stop_grace_secs: Grace period, in seconds, given to running threads to
        stop when `stop()` is called.  Defaults to 120 seconds.
      checkpoint_basename: The basename for checkpoint saving.
      session_manager: `SessionManager`, which manages Session creation and
        recovery. If it is `None`, a default `SessionManager` will be created
        with the set of arguments passed in for backwards compatibility.
      summary_writer: `SummaryWriter` to use or `USE_DEFAULT`.  Can be `None`
        to indicate that no summaries should be written.
      init_fn: Optional callable used to initialize the model. Called
        after the optional `init_op` is called.  The callable must accept one
        argument, the session being initialized.

    Returns:
      A `Supervisor`.

    Raises:
      RuntimeError: If called with eager execution enabled.

    @compatibility(eager)
    `Supervisor`s are not supported when eager execution is enabled.
    @end_compatibility
    """
    if context.in_eager_mode():
      raise RuntimeError(
          "Supervisors are not compatible with eager execution.")
    # Set default values of arguments.
    if graph is None:
      graph = ops.get_default_graph()
    with graph.as_default():
      self._init_ready_op(
          ready_op=ready_op, ready_for_local_init_op=ready_for_local_init_op)
      self._init_init_op(init_op=init_op, init_feed_dict=init_feed_dict)
      self._init_local_init_op(local_init_op=local_init_op)
      self._init_saver(saver=saver)
      self._init_summary_op(summary_op=summary_op)
      self._init_global_step(global_step=global_step)
    self._graph = graph
    self._meta_graph_def = meta_graph.create_meta_graph_def(
        graph_def=graph.as_graph_def(add_shapes=True),
        saver_def=self._saver.saver_def if self._saver else None)
    self._is_chief = is_chief
    self._coord = coordinator.Coordinator()
    self._recovery_wait_secs = recovery_wait_secs
    self._stop_grace_secs = stop_grace_secs
    self._init_fn = init_fn

    # Set all attributes related to checkpointing and writing events to None.
    # Afterwards, set them appropriately for chief supervisors, as these are
    # the only supervisors that can write checkpoints and events.
    self._logdir = None
    self._save_summaries_secs = None
    self._save_model_secs = None
    self._save_path = None
    self._summary_writer = None

    if self._is_chief:
      self._logdir = logdir
      self._save_summaries_secs = save_summaries_secs
      self._save_model_secs = save_model_secs
      if self._logdir:
        self._save_path = os.path.join(self._logdir, checkpoint_basename)
      if summary_writer is Supervisor.USE_DEFAULT:
        if self._logdir:
          self._summary_writer = _summary.FileWriter(self._logdir)
      else:
        self._summary_writer = summary_writer
      self._graph_added_to_summary = False

    self._init_session_manager(session_manager=session_manager)
    self._verify_setup()
    # The graph is not allowed to change anymore.
    graph.finalize()

  def _init_session_manager(self, session_manager=None):
    if session_manager is None:
      self._session_manager = session_manager_mod.SessionManager(
          local_init_op=self._local_init_op,
          ready_op=self._ready_op,
          ready_for_local_init_op=self._ready_for_local_init_op,
          graph=self._graph,
          recovery_wait_secs=self._recovery_wait_secs)
    else:
      self._session_manager = session_manager

  def _get_first_op_from_collection(self, key):
    """Returns the first `Operation` from a collection.

    Args:
      key: A string collection key.

    Returns:
      The first Op found in a collection, or `None` if the collection is empty.
    """
    try:
      op_list = ops.get_collection(key)
      if len(op_list) > 1:
        logging.info("Found %d %s operations. Returning the first one.",
                     len(op_list), key)
      if op_list:
        return op_list[0]
    except LookupError:
      pass

    return None

  def _init_ready_op(self,
                     ready_op=USE_DEFAULT,
                     ready_for_local_init_op=USE_DEFAULT):
    """Initializes ready_op.

    Args:
      ready_op: `Tensor` to check if the model is initialized.
        If it's set to USE_DEFAULT, creates an op that checks all
        the variables are initialized.
      ready_for_local_init_op: `Tensor` to check if the model is ready to run
        local_init_op.
        If it's set to USE_DEFAULT, creates an op that checks all
        the global variables are initialized.
    """
    if ready_op is Supervisor.USE_DEFAULT:
      ready_op = self._get_first_op_from_collection(ops.GraphKeys.READY_OP)
      if ready_op is None:
        ready_op = variables.report_uninitialized_variables()
        ops.add_to_collection(ops.GraphKeys.READY_OP, ready_op)
    self._ready_op = ready_op

    # ready_for_local_init_op defaults to None for backward compatibility
    if ready_for_local_init_op is Supervisor.USE_DEFAULT:
      ready_for_local_init_op = self._get_first_op_from_collection(
          ops.GraphKeys.READY_FOR_LOCAL_INIT_OP)
    self._ready_for_local_init_op = ready_for_local_init_op

  def _init_init_op(self, init_op=USE_DEFAULT, init_feed_dict=None):
    """Initializes init_op.

    Args:
      init_op: `Operation` to initialize the variables. If set to USE_DEFAULT,
        create an op that initializes all variables and tables.
      init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
        This feed dictionary will be used when `init_op` is evaluated.
    """
    if init_op is Supervisor.USE_DEFAULT:
      init_op = self._get_first_op_from_collection(ops.GraphKeys.INIT_OP)
      if init_op is None:
        init_op = variables.global_variables_initializer()
        ops.add_to_collection(ops.GraphKeys.INIT_OP, init_op)
    self._init_op = init_op
    self._init_feed_dict = init_feed_dict

  def _init_local_init_op(self, local_init_op=USE_DEFAULT):
    """Initializes local_init_op.

    Args:
      local_init_op: `Operation` run for every new supervisor instance. If set
        to USE_DEFAULT, use the first op from the GraphKeys.LOCAL_INIT_OP
        collection. If the collection is empty, create an op that initializes
        all local variables and all tables.
    """
    if local_init_op is Supervisor.USE_DEFAULT:
      local_init_op = self._get_first_op_from_collection(
          ops.GraphKeys.LOCAL_INIT_OP)
      if local_init_op is None:
        op_list = [
            variables.local_variables_initializer(),
            lookup_ops.tables_initializer()
        ]
        if op_list:
          local_init_op = control_flow_ops.group(*op_list)
          ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
    self._local_init_op = local_init_op

  def _init_saver(self, saver=USE_DEFAULT):
    """Initializes saver.

    Args:
      saver: A `Saver` object. If set to USE_DEFAULT, create one that
        saves all the variables.
    """
    if saver is Supervisor.USE_DEFAULT:
      saver = self._get_first_op_from_collection(ops.GraphKeys.SAVERS)
      if saver is None and variables.global_variables():
        saver = saver_mod.Saver()
        ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
    self._saver = saver

  def _init_summary_op(self, summary_op=USE_DEFAULT):
    """Initializes summary_op.

    Args:
      summary_op: An Operation that returns a Summary for the event logs.
        If set to USE_DEFAULT, create an op that merges all the summaries.
    """
    if summary_op is Supervisor.USE_DEFAULT:
      summary_op = self._get_first_op_from_collection(ops.GraphKeys.SUMMARY_OP)
      if summary_op is None:
        summary_op = _summary.merge_all()
        if summary_op is not None:
          ops.add_to_collection(ops.GraphKeys.SUMMARY_OP, summary_op)
    self._summary_op = summary_op

  def _init_global_step(self, global_step=USE_DEFAULT):
    """Initializes global_step.

    Args:
      global_step: An integer Tensor of size 1 that counts steps. If
        set to USE_DEFAULT, creates global_step tensor.
    """
    if global_step is Supervisor.USE_DEFAULT:
      global_step = self._get_first_op_from_collection(
          ops.GraphKeys.GLOBAL_STEP)
      if global_step is None:
        global_step = self._default_global_step_tensor()
        if global_step is not None:
          ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, global_step)
    self._global_step = global_step

  @property
  def is_chief(self):
    """Return True if this is a chief supervisor.

    Returns:
      A bool.
    """
    return self._is_chief

  @property
  def session_manager(self):
    """Return the SessionManager used by the Supervisor.

    Returns:
      A SessionManager object.
    """
    return self._session_manager

  @property
  def coord(self):
    """Return the Coordinator used by the Supervisor.

    The Coordinator can be useful if you want to run multiple threads
    during your training.

    Returns:
      A Coordinator object.
    """
    return self._coord

  @property
  def init_op(self):
    """Return the Init Op used by the supervisor.

    Returns:
      An Op or `None`.
    """
    return self._init_op

  @property
  def init_feed_dict(self):
    """Return the feed dictionary used when evaluating the `init_op`.

    Returns:
      A feed dictionary or `None`.
    """
    return self._init_feed_dict

  @property
  def ready_op(self):
    """Return the Ready Op used by the supervisor.

    Returns:
      An Op or `None`.
    """
    return self._ready_op

  @property
  def ready_for_local_init_op(self):
    """Return the Ready-for-local-init Op used by the supervisor.

    Returns:
      An Op or `None`.
    """
    return self._ready_for_local_init_op

  @property
  def summary_writer(self):
    """Return the SummaryWriter used by the chief supervisor.

    Returns:
      A SummaryWriter.
    """
    return self._summary_writer

  @property
  def summary_op(self):
    """Return the Summary Tensor used by the chief supervisor.

    Returns:
      A string Tensor for the summary or `None`.
    """
    return self._summary_op

  @property
  def save_summaries_secs(self):
    """Return the delay between summary computations.

    Returns:
      A number of seconds.
    """
    return self._save_summaries_secs

  @property
  def global_step(self):
    """Return the global_step Tensor used by the supervisor.

    Returns:
      An integer Tensor for the global_step.
    """
    return self._global_step

  @property
  def saver(self):
    """Return the Saver used by the supervisor.

    Returns:
      A Saver object.
    """
    return self._saver

  @property
  def save_model_secs(self):
    """Return the delay between checkpoints.

    Returns:
      A number of seconds.
    """
    return self._save_model_secs

  @property
  def save_path(self):
    """Return the save path used by the supervisor.

    Returns:
      A string.
    """
    return self._save_path

  def _write_graph(self):
    """Writes graph_def to `logdir` and adds it to summary if applicable."""
    assert self._is_chief
    if self._logdir:
      training_util.write_graph(self._graph.as_graph_def(add_shapes=True),
                                self._logdir, "graph.pbtxt")
    if self._summary_writer and not self._graph_added_to_summary:
      self._summary_writer.add_graph(self._graph)
      self._summary_writer.add_meta_graph(self._meta_graph_def)
      self._graph_added_to_summary = True

  def start_standard_services(self, sess):
    """Start the standard services for 'sess'.

    This starts services in the background.  The services started depend
    on the parameters to the constructor and may include:

      - A Summary thread computing summaries every save_summaries_secs.
      - A Checkpoint thread saving the model every save_model_secs.
      - A StepCounter thread measuring step time.

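    A sketch of starting the services by hand: create the session with
    `start_standard_services=False`, then start them here.

    ```python
    sess = sv.prepare_or_wait_for_session(FLAGS.master,
                                          start_standard_services=False)
    threads = sv.start_standard_services(sess)
    ```
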
    Args:
      sess: A Session.

    Returns:
      A list of threads that are running the standard services.  You can use
      the Supervisor's Coordinator to join these threads with:
        sv.coord.join(<list of threads>)

    Raises:
      RuntimeError: If called with a non-chief Supervisor.
      ValueError: If no `logdir` was passed to the constructor, as the
        services need a log directory.
    """
    if not self._is_chief:
      raise RuntimeError("Only the chief supervisor can start the standard "
                         "services, because only chief supervisors can "
                         "write events.")

    if not self._logdir:
      logging.warning("Standard services need a 'logdir' "
                      "passed to the Supervisor")
      return

    if self._global_step is not None and self._summary_writer:
      # Only add the session log if we keep track of global step.
      # TensorBoard cannot use START message for purging expired events
      # if there is no step value.
      current_step = training_util.global_step(sess, self._global_step)
      self._summary_writer.add_session_log(
          SessionLog(status=SessionLog.START),
          current_step)

    threads = []
    if self._save_summaries_secs and self._summary_writer:
      if self._summary_op is not None:
        threads.append(SVSummaryThread(self, sess))
      if self._global_step is not None:
        threads.append(SVStepCounterThread(self, sess))
    if self.saver and self._save_model_secs:
      threads.append(SVTimerCheckpointThread(self, sess))
    for t in threads:
      t.start()
    return threads

  def prepare_or_wait_for_session(self, master="", config=None,
                                  wait_for_checkpoint=False,
                                  max_wait_secs=7200,
                                  start_standard_services=True):
    """Make sure the model is ready to be used.

    Create a session on 'master', recovering or initializing the model as
    needed, or wait for a session to be ready.  If running as the chief
    and `start_standard_services` is set to True, also call the session
    manager to start the standard services.

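    A sketch of driving training without `managed_session()` (the caller
    is then responsible for stopping and for error handling):

    ```python
    sess = sv.prepare_or_wait_for_session(server.target)
    while not sv.should_stop():
      sess.run(my_train_op)
    sv.stop()
    ```
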
    Args:
      master: name of the TensorFlow master to use.  See the `tf.Session`
        constructor for how this is interpreted.
      config: Optional ConfigProto proto used to configure the session,
        which is passed as-is to create the session.
      wait_for_checkpoint: Whether we should wait for the availability of a
        checkpoint before creating Session. Defaults to False.
      max_wait_secs: Maximum time to wait for the session to become available.
      start_standard_services: Whether to start the standard services and the
        queue runners.

    Returns:
      A Session object that can be used to drive the model.
    """
    # For users who recreate the session with prepare_or_wait_for_session(), we
    # need to clear the coordinator's stop_event so that threads managed by the
    # coordinator can run.
    self._coord.clear_stop()
    if self._summary_writer:
      self._summary_writer.reopen()

    if self._is_chief:
      sess = self._session_manager.prepare_session(
          master, init_op=self.init_op, saver=self.saver,
          checkpoint_dir=self._logdir, wait_for_checkpoint=wait_for_checkpoint,
          max_wait_secs=max_wait_secs, config=config,
          init_feed_dict=self._init_feed_dict, init_fn=self._init_fn)
      self._write_graph()
      if start_standard_services:
        logging.info("Starting standard services.")
        self.start_standard_services(sess)
    else:
      sess = self._session_manager.wait_for_session(master,
                                                    config=config,
                                                    max_wait_secs=max_wait_secs)
    if start_standard_services:
      logging.info("Starting queue runners.")
      self.start_queue_runners(sess)
    return sess

  def start_queue_runners(self, sess, queue_runners=None):
    """Start threads for `QueueRunners`.

    Note that the queue runners collected in the graph key `QUEUE_RUNNERS`
    are already started automatically when you create a session with the
    supervisor, so unless you have non-collected queue runners to start
    you do not need to call this explicitly.

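    For example (a sketch; `my_queue_runner` stands in for a
    `tf.train.QueueRunner` you created but did not add to the collection):

    ```python
    threads = sv.start_queue_runners(sess, [my_queue_runner])
    ```
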
    Args:
      sess: A `Session`.
      queue_runners: A list of `QueueRunners`. If not specified, we'll use the
        list of queue runners gathered in the graph under the key
        `GraphKeys.QUEUE_RUNNERS`.

    Returns:
      The list of threads started for the `QueueRunners`.

    Raises:
      RuntimeError: If called with eager execution enabled.

    @compatibility(eager)
    Queues are not compatible with eager execution. To ingest data when eager
    execution is enabled, use the `tf.data` API.
    @end_compatibility
    """
    if context.in_eager_mode():
      raise RuntimeError("Queues are not compatible with eager execution.")
    if queue_runners is None:
      queue_runners = self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
    threads = []
    for qr in queue_runners:
      threads.extend(qr.create_threads(sess, coord=self._coord, daemon=True,
                                       start=True))
    return threads

  def loop(self, timer_interval_secs, target, args=None, kwargs=None):
    """Start a LooperThread that calls a function periodically.

    If `timer_interval_secs` is None the thread calls `target(*args, **kwargs)`
    repeatedly.  Otherwise it calls it every `timer_interval_secs`
    seconds.  The thread terminates when a stop is requested.

    The started thread is added to the list of threads managed by the
    supervisor so it does not need to be passed to the `stop()` method.

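    For example (a sketch; `print_loss` is a placeholder callable):

    ```python
    sv.loop(60, print_loss, args=(sess,))
    ```
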
    Args:
      timer_interval_secs: Number of seconds to wait between calls to
        `target`, or `None`.
      target: A callable object.
      args: Optional arguments to pass to `target` when calling it.
      kwargs: Optional keyword arguments to pass to `target` when calling it.

    Returns:
      The started thread.
    """
    looper = coordinator.LooperThread(self._coord, timer_interval_secs,
                                      target=target, args=args, kwargs=kwargs)
    looper.start()
    return looper

  def stop(self,
           threads=None,
           close_summary_writer=True,
           ignore_live_threads=False):
    """Stop the services and the coordinator.

    This does not close the session.

    Args:
      threads: Optional list of threads to join with the coordinator.  If
        `None`, defaults to the threads running the standard services, the
        threads started for `QueueRunners`, and the threads started by the
        `loop()` method.  To wait on additional threads, pass the
        list in this parameter.
      close_summary_writer: Whether to close the `summary_writer`.  Defaults to
        `True` if the summary writer was created by the supervisor, `False`
        otherwise.
      ignore_live_threads: If `True` ignores threads that remain running after
        a grace period when joining threads via the coordinator, instead of
        raising a RuntimeError.
    """
    self._coord.request_stop()
    try:
      # coord.join() re-raises the first reported exception; the "finally"
      # block ensures that we clean up whether or not an exception was
      # reported.
      self._coord.join(
          threads,
          stop_grace_period_secs=self._stop_grace_secs,
          ignore_live_threads=ignore_live_threads)
    finally:
      # Close the writer last, in case one of the running threads was using it.
      if close_summary_writer and self._summary_writer:
        # Stop messages are not logged with event.step,
        # since the session may have already terminated.
        self._summary_writer.add_session_log(SessionLog(status=SessionLog.STOP))
        self._summary_writer.close()
        self._graph_added_to_summary = False

  def request_stop(self, ex=None):
    """Request that the coordinator stop the threads.

    See `Coordinator.request_stop()`.

    Args:
      ex: Optional `Exception`, or Python `exc_info` tuple as returned by
        `sys.exc_info()`.  If this is the first call to `request_stop()` the
        corresponding exception is recorded and re-raised from `join()`.
    """
    self._coord.request_stop(ex=ex)

  def should_stop(self):
    """Check if the coordinator was told to stop.

    See `Coordinator.should_stop()`.

    Returns:
      True if the coordinator was told to stop, False otherwise.
    """
    return self._coord.should_stop()

  def stop_on_exception(self):
    """Context handler to stop the supervisor when an exception is raised.

    See `Coordinator.stop_on_exception()`.
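
    Typical use (a sketch):

    ```python
    with sv.stop_on_exception():
      while not sv.should_stop():
        sess.run(my_train_op)
    ```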

    Returns:
      A context handler.
    """
    return self._coord.stop_on_exception()

  def wait_for_stop(self):
    """Block waiting for the coordinator to stop."""
    self._coord.wait_for_stop()

  def summary_computed(self, sess, summary, global_step=None):
    """Indicate that a summary was computed.

    Args:
      sess: A `Session` object.
      summary: A Summary proto, or a string holding a serialized summary proto.
      global_step: Int. The global step this summary is associated with. If
        `None`, it will try to fetch the current step.

    Raises:
      TypeError: if 'summary' is not a Summary proto or a string.
      RuntimeError: if the Supervisor was created without a `logdir`.
    """
    if not self._summary_writer:
      raise RuntimeError("Writing a summary requires a summary writer.")
    if global_step is None and self.global_step is not None:
      global_step = training_util.global_step(sess, self.global_step)
    self._summary_writer.add_summary(summary, global_step)

  def _default_global_step_tensor(self):
    """Returns the global_step from the default graph.

    Returns:
      The global step `Tensor` or `None`.
    """
    try:
      gs = ops.get_default_graph().get_tensor_by_name("global_step:0")
      if gs.dtype.base_dtype in [dtypes.int32, dtypes.int64]:
        return gs
      else:
        logging.warning("Found 'global_step' but it is not an int type: %s",
                        gs.dtype)
        return None
    except KeyError:
      return None

  def _verify_setup(self):
    """Check that all is good.

    Raises:
      ValueError: If something is not good.
    """
    # Not running as chief means that replicas are used.
    # In that case all Variables must have their device set.
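    # For example, variables created under
    # `tf.train.replica_device_setter(...)` get an explicit device
    # assignment and therefore pass this check.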
    if not self._is_chief:
      for op in self._graph.get_operations():
        if op.type in ["Variable", "VariableV2"] and not op.device:
          raise ValueError("When using replicas, all Variables must have "
                           "their device set: %s" % op)

  # pylint: disable=g-doc-return-or-yield,broad-except
  @contextlib.contextmanager
  def managed_session(self, master="", config=None,
                      start_standard_services=True,
                      close_summary_writer=True):
    """Returns a context manager for a managed session.

    This context manager creates and automatically recovers a session.  It
    optionally starts the standard services that handle checkpoints and
    summaries.  It monitors exceptions raised from the `with` block or from the
    services and stops the supervisor as needed.

    The context manager is typically used as follows:

    ```python
    def train():
      sv = tf.train.Supervisor(...)
      with sv.managed_session(<master>) as sess:
        for step in xrange(..):
          if sv.should_stop():
            break
          sess.run(<my training op>)
          ...do other things needed at each training step...
    ```

    An exception raised from the `with` block or one of the service threads is
    raised again when the block exits.  This is done after stopping all threads
    and closing the session.  For example, an `AbortedError` exception, raised
    in case of preemption of one of the workers in a distributed model, is
    raised again when the block exits.

    If you want to retry the training loop in case of preemption you can do it
    as follows:

    ```python
    def main(...):
      while True:
        try:
          train()
        except tf.errors.AbortedError:
          pass
    ```

    As a special case, exceptions used for control flow, such as
    `OutOfRangeError` which reports that input queues are exhausted, are not
    raised again from the `with` block: they indicate a clean termination of
    the training loop and are considered normal termination.

    Args:
      master: name of the TensorFlow master to use.  See the `tf.Session`
        constructor for how this is interpreted.
      config: Optional `ConfigProto` proto used to configure the session.
        Passed as-is to create the session.
      start_standard_services: Whether to start the standard services,
        such as checkpoint, summary and step counter.
      close_summary_writer: Whether to close the summary writer when
        closing the session.  Defaults to True.

    Returns:
      A context manager that yields a `Session` restored from the latest
      checkpoint or initialized from scratch if no checkpoint exists.  The
      session is closed when the `with` block exits.
    """
    try:
      sess = self.prepare_or_wait_for_session(
          master=master, config=config,
          start_standard_services=start_standard_services)
      yield sess
    except Exception as e:
      self.request_stop(e)
    finally:
      try:
        # Request all the threads to stop and wait for them to do so.  Any
        # exception raised by the threads is raised again from stop().
        # Passing stop_grace_period_secs is for blocked enqueue/dequeue
        # threads which are not checking for `should_stop()`.  They
        # will be stopped when we close the session further down.
        self.stop(close_summary_writer=close_summary_writer)
      finally:
        # Close the session to finish up all pending calls.  We do not care
        # about exceptions raised when closing.  This takes care of
        # blocked enqueue/dequeue calls.
        try:
          sess.close()
        except Exception:
          # Silently ignore exceptions raised by close().
          pass
  # pylint: enable=g-doc-return-or-yield,broad-except


class SVSummaryThread(coordinator.LooperThread):
  """A thread to save summaries on a timer."""

  def __init__(self, sv, sess):
    """Create a SVSummaryThread.

    Args:
      sv: A `Supervisor`.
      sess: A `Session`.
    """
    super(SVSummaryThread, self).__init__(sv.coord, sv.save_summaries_secs)
    self._sv = sv
    self._sess = sess

  def run_loop(self):
    if self._sv.global_step is not None:
      summary_strs, global_step = self._sess.run([self._sv.summary_op,
                                                  self._sv.global_step])
    else:
      summary_strs = self._sess.run(self._sv.summary_op)
      global_step = None
    if self._sv.summary_writer:
      logging.info("Recording summary at step %s.", global_step)
      self._sv.summary_writer.add_summary(summary_strs, global_step)


class SVStepCounterThread(coordinator.LooperThread):
  """A thread to count steps and measure their duration."""

  def __init__(self, sv, sess, step_counter=None):
    """Create a `SVStepCounterThread`.

    Args:
      sv: A `Supervisor`.
      sess: A `Session`.
      step_counter: A `Tensor` holding the step counter. By default, it uses
        sv.global_step.
    """
    super(SVStepCounterThread, self).__init__(sv.coord, sv.save_summaries_secs)
    self._sv = sv
    self._sess = sess
    self._last_time = 0.0
    self._last_step = 0
    step_counter = sv.global_step if step_counter is None else step_counter
    self._step_counter = step_counter
    self._summary_tag = "%s/sec" % self._step_counter.op.name

  def start_loop(self):
    self._last_time = time.time()
    self._last_step = training_util.global_step(
        self._sess, self._step_counter)

  def run_loop(self):
    # Count the steps.
    current_step = training_util.global_step(self._sess, self._step_counter)
    added_steps = current_step - self._last_step
    self._last_step = current_step
    # Measure the elapsed time.
    current_time = time.time()
    elapsed_time = current_time - self._last_time
    self._last_time = current_time
    # Report the number of steps done per second.
    if elapsed_time > 0.:
      steps_per_sec = added_steps / elapsed_time
    else:
      steps_per_sec = float("inf")
    summary = Summary(value=[Summary.Value(tag=self._summary_tag,
                                           simple_value=steps_per_sec)])
    if self._sv.summary_writer:
      self._sv.summary_writer.add_summary(summary, current_step)
    logging.log_first_n(logging.INFO, "%s: %g", 10,
                        self._summary_tag, steps_per_sec)


class SVTimerCheckpointThread(coordinator.LooperThread):
  """A thread to checkpoint on a timer."""

  def __init__(self, sv, sess):
    """Create a `SVTimerCheckpointThread`.

    Args:
      sv: A `Supervisor`.
      sess: A `Session`.
    """
    super(SVTimerCheckpointThread, self).__init__(sv.coord, sv.save_model_secs)
    self._sv = sv
    self._sess = sess

  def run_loop(self):
    logging.info("Saving checkpoint to path %s", self._sv.save_path)
    self._sv.saver.save(self._sess, self._sv.save_path,
                        global_step=self._sv.global_step)
    if self._sv.summary_writer and self._sv.global_step is not None:
      current_step = training_util.global_step(self._sess, self._sv.global_step)
      self._sv.summary_writer.add_session_log(
          SessionLog(status=SessionLog.CHECKPOINT,
                     checkpoint_path=self._sv.save_path),
          current_step)


# TODO(sherrym): All non-PEP8 compliant names will be deprecated shortly.
setattr(Supervisor, "PrepareSession", Supervisor.prepare_or_wait_for_session)
setattr(Supervisor, "StartQueueRunners", Supervisor.start_queue_runners)
setattr(Supervisor, "StartStandardServices", Supervisor.start_standard_services)
setattr(Supervisor, "Stop", Supervisor.stop)
setattr(Supervisor, "RequestStop", Supervisor.request_stop)
setattr(Supervisor, "Loop", Supervisor.loop)
setattr(Supervisor, "ShouldStop", Supervisor.should_stop)
setattr(Supervisor, "StopOnException", Supervisor.stop_on_exception)
setattr(Supervisor, "WaitForStop", Supervisor.wait_for_stop)
setattr(Supervisor, "SummaryComputed", Supervisor.summary_computed)