Home | History | Annotate | Download | only in training
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Training helper that checkpoints models and creates session."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import time
     21 import numpy as np
     22 
     23 from tensorflow.python.client import session
     24 from tensorflow.python.framework import errors
     25 from tensorflow.python.framework import ops
     26 from tensorflow.python.platform import tf_logging as logging
     27 from tensorflow.python.training import saver as saver_mod
     28 from tensorflow.python.util.tf_export import tf_export
     29 
     30 
     31 def _maybe_name(obj):
     32   """Returns object name if it has one, or a message otherwise.
     33 
     34   This is useful for names that apper in error messages.
     35   Args:
     36     obj: Object to get the name of.
     37   Returns:
     38     name, "None", or a "no name" message.
     39   """
     40   if obj is None:
     41     return "None"
     42   elif hasattr(obj, "name"):
     43     return obj.name
     44   else:
     45     return "<no name for %s>" % type(obj)
     46 
     47 
     48 @tf_export("train.SessionManager")
     49 class SessionManager(object):
     50   """Training helper that restores from checkpoint and creates session.
     51 
     52   This class is a small wrapper that takes care of session creation and
     53   checkpoint recovery. It also provides functions that to facilitate
     54   coordination among multiple training threads or processes.
     55 
     56   * Checkpointing trained variables as the training progresses.
     57   * Initializing variables on startup, restoring them from the most recent
     58     checkpoint after a crash, or wait for checkpoints to become available.
     59 
     60   ### Usage:
     61 
     62   ```python
     63   with tf.Graph().as_default():
     64      ...add operations to the graph...
     65     # Create a SessionManager that will checkpoint the model in '/tmp/mydir'.
     66     sm = SessionManager()
     67     sess = sm.prepare_session(master, init_op, saver, checkpoint_dir)
     68     # Use the session to train the graph.
     69     while True:
     70       sess.run(<my_train_op>)
     71   ```
     72 
     73   `prepare_session()` initializes or restores a model. It requires `init_op`
     74   and `saver` as an argument.
     75 
     76   A second process could wait for the model to be ready by doing the following:
     77 
     78   ```python
     79   with tf.Graph().as_default():
     80      ...add operations to the graph...
     81     # Create a SessionManager that will wait for the model to become ready.
     82     sm = SessionManager()
     83     sess = sm.wait_for_session(master)
     84     # Use the session to train the graph.
     85     while True:
     86       sess.run(<my_train_op>)
     87   ```
     88 
     89   `wait_for_session()` waits for a model to be initialized by other processes.
     90 
     91   """
     92 
     93   def __init__(self,
     94                local_init_op=None,
     95                ready_op=None,
     96                ready_for_local_init_op=None,
     97                graph=None,
     98                recovery_wait_secs=30):
     99     """Creates a SessionManager.
    100 
    101     The `local_init_op` is an `Operation` that is run always after a new session
    102     was created. If `None`, this step is skipped.
    103 
    104     The `ready_op` is an `Operation` used to check if the model is ready.  The
    105     model is considered ready if that operation returns an empty 1D string
    106     tensor. If the operation returns a non empty 1D string tensor, the elements
    107     are concatenated and used to indicate to the user why the model is not
    108     ready.
    109 
    110     The `ready_for_local_init_op` is an `Operation` used to check if the model
    111     is ready to run local_init_op.  The model is considered ready if that
    112     operation returns an empty 1D string tensor. If the operation returns a non
    113     empty 1D string tensor, the elements are concatenated and used to indicate
    114     to the user why the model is not ready.
    115 
    116     If `ready_op` is `None`, the model is not checked for readiness.
    117 
    118     `recovery_wait_secs` is the number of seconds between checks that
    119     the model is ready.  It is used by processes to wait for a model to
    120     be initialized or restored.  Defaults to 30 seconds.
    121 
    122     Args:
    123       local_init_op: An `Operation` run immediately after session creation.
    124          Usually used to initialize tables and local variables.
    125       ready_op: An `Operation` to check if the model is initialized.
    126       ready_for_local_init_op: An `Operation` to check if the model is ready
    127          to run local_init_op.
    128       graph: The `Graph` that the model will use.
    129       recovery_wait_secs: Seconds between checks for the model to be ready.
    130 
    131     Raises:
    132       ValueError: If ready_for_local_init_op is not None but local_init_op is
    133         None
    134     """
    135     # Sets default values of arguments.
    136     if graph is None:
    137       graph = ops.get_default_graph()
    138     self._local_init_op = local_init_op
    139     self._ready_op = ready_op
    140     self._ready_for_local_init_op = ready_for_local_init_op
    141     self._graph = graph
    142     self._recovery_wait_secs = recovery_wait_secs
    143     self._target = None
    144     if ready_for_local_init_op is not None and local_init_op is None:
    145       raise ValueError("If you pass a ready_for_local_init_op "
    146                        "you must also pass a local_init_op "
    147                        ", ready_for_local_init_op [%s]" %
    148                        ready_for_local_init_op)
    149 
    150   def _restore_checkpoint(self,
    151                           master,
    152                           saver=None,
    153                           checkpoint_dir=None,
    154                           checkpoint_filename_with_path=None,
    155                           wait_for_checkpoint=False,
    156                           max_wait_secs=7200,
    157                           config=None):
    158     """Creates a `Session`, and tries to restore a checkpoint.
    159 
    160 
    161     Args:
    162       master: `String` representation of the TensorFlow master to use.
    163       saver: A `Saver` object used to restore a model.
    164       checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
    165         dir will be used to restore.
    166       checkpoint_filename_with_path: Full file name path to the checkpoint file.
    167       wait_for_checkpoint: Whether to wait for checkpoint to become available.
    168       max_wait_secs: Maximum time to wait for checkpoints to become available.
    169       config: Optional `ConfigProto` proto used to configure the session.
    170 
    171     Returns:
    172       A pair (sess, is_restored) where 'is_restored' is `True` if
    173       the session could be restored, `False` otherwise.
    174 
    175     Raises:
    176       ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
    177         set.
    178     """
    179     self._target = master
    180     sess = session.Session(self._target, graph=self._graph, config=config)
    181 
    182     if checkpoint_dir and checkpoint_filename_with_path:
    183       raise ValueError("Can not provide both checkpoint_dir and "
    184                        "checkpoint_filename_with_path.")
    185     # If either saver or checkpoint_* is not specified, cannot restore. Just
    186     # return.
    187     if not saver or not (checkpoint_dir or checkpoint_filename_with_path):
    188       return sess, False
    189 
    190     if checkpoint_filename_with_path:
    191       saver.restore(sess, checkpoint_filename_with_path)
    192       return sess, True
    193 
    194     # Waits up until max_wait_secs for checkpoint to become available.
    195     wait_time = 0
    196     ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
    197     while not ckpt or not ckpt.model_checkpoint_path:
    198       if wait_for_checkpoint and wait_time < max_wait_secs:
    199         logging.info("Waiting for checkpoint to be available.")
    200         time.sleep(self._recovery_wait_secs)
    201         wait_time += self._recovery_wait_secs
    202         ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
    203       else:
    204         return sess, False
    205 
    206     # Loads the checkpoint.
    207     saver.restore(sess, ckpt.model_checkpoint_path)
    208     saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
    209     return sess, True
    210 
    211   def prepare_session(self,
    212                       master,
    213                       init_op=None,
    214                       saver=None,
    215                       checkpoint_dir=None,
    216                       checkpoint_filename_with_path=None,
    217                       wait_for_checkpoint=False,
    218                       max_wait_secs=7200,
    219                       config=None,
    220                       init_feed_dict=None,
    221                       init_fn=None):
    222     """Creates a `Session`. Makes sure the model is ready to be used.
    223 
    224     Creates a `Session` on 'master'. If a `saver` object is passed in, and
    225     `checkpoint_dir` points to a directory containing valid checkpoint
    226     files, then it will try to recover the model from checkpoint. If
    227     no checkpoint files are available, and `wait_for_checkpoint` is
    228     `True`, then the process would check every `recovery_wait_secs`,
    229     up to `max_wait_secs`, for recovery to succeed.
    230 
    231     If the model cannot be recovered successfully then it is initialized by
    232     either running the provided `init_op`, or calling the provided `init_fn`.
    233     The local_init_op is also run after init_op and init_fn, regardless of
    234     whether the model was recovered successfully, but only if
    235     ready_for_local_init_op passes.
    236 
    237     It is an error if the model cannot be recovered and no `init_op`
    238     or `init_fn` or `local_init_op` are passed.
    239 
    240     Args:
    241       master: `String` representation of the TensorFlow master to use.
    242       init_op: Optional `Operation` used to initialize the model.
    243       saver: A `Saver` object used to restore a model.
    244       checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
    245         dir will be used to restore.
    246       checkpoint_filename_with_path: Full file name path to the checkpoint file.
    247       wait_for_checkpoint: Whether to wait for checkpoint to become available.
    248       max_wait_secs: Maximum time to wait for checkpoints to become available.
    249       config: Optional `ConfigProto` proto used to configure the session.
    250       init_feed_dict: Optional dictionary that maps `Tensor` objects to feed
    251         values.  This feed dictionary is passed to the session `run()` call when
    252         running the init op.
    253       init_fn: Optional callable used to initialize the model. Called after the
    254         optional `init_op` is called.  The callable must accept one argument,
    255         the session being initialized.
    256 
    257     Returns:
    258       A `Session` object that can be used to drive the model.
    259 
    260     Raises:
    261       RuntimeError: If the model cannot be initialized or recovered.
    262 
    263     Raises:
    264       ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
    265         set.
    266     """
    267 
    268     sess, is_loaded_from_checkpoint = self._restore_checkpoint(
    269         master,
    270         saver,
    271         checkpoint_dir=checkpoint_dir,
    272         checkpoint_filename_with_path=checkpoint_filename_with_path,
    273         wait_for_checkpoint=wait_for_checkpoint,
    274         max_wait_secs=max_wait_secs,
    275         config=config)
    276     if not is_loaded_from_checkpoint:
    277       if init_op is None and not init_fn and self._local_init_op is None:
    278         raise RuntimeError("Model is not initialized and no init_op or "
    279                            "init_fn or local_init_op was given")
    280       if init_op is not None:
    281         sess.run(init_op, feed_dict=init_feed_dict)
    282       if init_fn:
    283         init_fn(sess)
    284 
    285     local_init_success, msg = self._try_run_local_init_op(sess)
    286     if not local_init_success:
    287       raise RuntimeError(
    288           "Init operations did not make model ready for local_init.  "
    289           "Init op: %s, init fn: %s, error: %s" % (_maybe_name(init_op),
    290                                                    init_fn,
    291                                                    msg))
    292 
    293     is_ready, msg = self._model_ready(sess)
    294     if not is_ready:
    295       raise RuntimeError(
    296           "Init operations did not make model ready.  "
    297           "Init op: %s, init fn: %s, local_init_op: %s, error: %s" %
    298           (_maybe_name(init_op), init_fn, self._local_init_op, msg))
    299     return sess
    300 
    301   def recover_session(self,
    302                       master,
    303                       saver=None,
    304                       checkpoint_dir=None,
    305                       checkpoint_filename_with_path=None,
    306                       wait_for_checkpoint=False,
    307                       max_wait_secs=7200,
    308                       config=None):
    309     """Creates a `Session`, recovering if possible.
    310 
    311     Creates a new session on 'master'.  If the session is not initialized
    312     and can be recovered from a checkpoint, recover it.
    313 
    314     Args:
    315       master: `String` representation of the TensorFlow master to use.
    316       saver: A `Saver` object used to restore a model.
    317       checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
    318         dir will be used to restore.
    319       checkpoint_filename_with_path: Full file name path to the checkpoint file.
    320       wait_for_checkpoint: Whether to wait for checkpoint to become available.
    321       max_wait_secs: Maximum time to wait for checkpoints to become available.
    322       config: Optional `ConfigProto` proto used to configure the session.
    323 
    324     Returns:
    325       A pair (sess, initialized) where 'initialized' is `True` if
    326       the session could be recovered and initialized, `False` otherwise.
    327 
    328     Raises:
    329       ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
    330         set.
    331     """
    332 
    333     sess, is_loaded_from_checkpoint = self._restore_checkpoint(
    334         master,
    335         saver,
    336         checkpoint_dir=checkpoint_dir,
    337         checkpoint_filename_with_path=checkpoint_filename_with_path,
    338         wait_for_checkpoint=wait_for_checkpoint,
    339         max_wait_secs=max_wait_secs,
    340         config=config)
    341 
    342     # Always try to run local_init_op
    343     local_init_success, msg = self._try_run_local_init_op(sess)
    344 
    345     if not is_loaded_from_checkpoint:
    346       # Do not need to run checks for readiness
    347       return sess, False
    348 
    349     restoring_file = checkpoint_dir or checkpoint_filename_with_path
    350     if not local_init_success:
    351       logging.info(
    352           "Restoring model from %s did not make model ready for local init:"
    353           " %s", restoring_file, msg)
    354       return sess, False
    355 
    356     is_ready, msg = self._model_ready(sess)
    357     if not is_ready:
    358       logging.info("Restoring model from %s did not make model ready: %s",
    359                    restoring_file, msg)
    360       return sess, False
    361 
    362     logging.info("Restored model from %s", restoring_file)
    363     return sess, is_loaded_from_checkpoint
    364 
    365   def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")):
    366     """Creates a new `Session` and waits for model to be ready.
    367 
    368     Creates a new `Session` on 'master'.  Waits for the model to be
    369     initialized or recovered from a checkpoint.  It's expected that
    370     another thread or process will make the model ready, and that this
    371     is intended to be used by threads/processes that participate in a
    372     distributed training configuration where a different thread/process
    373     is responsible for initializing or recovering the model being trained.
    374 
    375     NB: The amount of time this method waits for the session is bounded
    376     by max_wait_secs. By default, this function will wait indefinitely.
    377 
    378     Args:
    379       master: `String` representation of the TensorFlow master to use.
    380       config: Optional ConfigProto proto used to configure the session.
    381       max_wait_secs: Maximum time to wait for the session to become available.
    382 
    383     Returns:
    384       A `Session`. May be None if the operation exceeds the timeout
    385       specified by config.operation_timeout_in_ms.
    386 
    387     Raises:
    388       tf.DeadlineExceededError: if the session is not available after
    389         max_wait_secs.
    390     """
    391     self._target = master
    392 
    393     if max_wait_secs is None:
    394       max_wait_secs = float("Inf")
    395     timer = _CountDownTimer(max_wait_secs)
    396 
    397     while True:
    398       sess = session.Session(self._target, graph=self._graph, config=config)
    399       not_ready_msg = None
    400       not_ready_local_msg = None
    401       local_init_success, not_ready_local_msg = self._try_run_local_init_op(
    402           sess)
    403       if local_init_success:
    404         # Successful if local_init_op is None, or ready_for_local_init_op passes
    405         is_ready, not_ready_msg = self._model_ready(sess)
    406         if is_ready:
    407           return sess
    408 
    409       self._safe_close(sess)
    410 
    411       # Do we have enough time left to try again?
    412       remaining_ms_after_wait = (
    413           timer.secs_remaining() - self._recovery_wait_secs)
    414       if remaining_ms_after_wait < 0:
    415         raise errors.DeadlineExceededError(
    416             None, None,
    417             "Session was not ready after waiting %d secs." % (max_wait_secs,))
    418 
    419       logging.info("Waiting for model to be ready.  "
    420                    "Ready_for_local_init_op:  %s, ready: %s",
    421                    not_ready_local_msg, not_ready_msg)
    422       time.sleep(self._recovery_wait_secs)
    423 
    424   def _safe_close(self, sess):
    425     """Closes a session without raising an exception.
    426 
    427     Just like sess.close() but ignores exceptions.
    428 
    429     Args:
    430       sess: A `Session`.
    431     """
    432     # pylint: disable=broad-except
    433     try:
    434       sess.close()
    435     except Exception:
    436       # Intentionally not logging to avoid user complaints that
    437       # they get cryptic errors.  We really do not care that Close
    438       # fails.
    439       pass
    440     # pylint: enable=broad-except
    441 
    442   def _model_ready(self, sess):
    443     """Checks if the model is ready or not.
    444 
    445     Args:
    446       sess: A `Session`.
    447 
    448     Returns:
    449       A tuple (is_ready, msg), where is_ready is True if ready and False
    450       otherwise, and msg is `None` if the model is ready, a `String` with the
    451       reason why it is not ready otherwise.
    452     """
    453     return _ready(self._ready_op, sess, "Model not ready")
    454 
    455   def _model_ready_for_local_init(self, sess):
    456     """Checks if the model is ready to run local_init_op.
    457 
    458     Args:
    459       sess: A `Session`.
    460 
    461     Returns:
    462       A tuple (is_ready, msg), where is_ready is True if ready to run
    463       local_init_op and False otherwise, and msg is `None` if the model is
    464       ready to run local_init_op, a `String` with the reason why it is not ready
    465       otherwise.
    466     """
    467     return _ready(self._ready_for_local_init_op, sess,
    468                   "Model not ready for local init")
    469 
    470   def _try_run_local_init_op(self, sess):
    471     """Tries to run _local_init_op, if not None, and is ready for local init.
    472 
    473     Args:
    474       sess: A `Session`.
    475 
    476     Returns:
    477       A tuple (is_successful, msg), where is_successful is True if
    478       _local_init_op is None, or we ran _local_init_op, and False otherwise;
    479       and msg is a `String` with the reason why the model was not ready to run
    480       local init.
    481     """
    482     if self._local_init_op is not None:
    483       is_ready_for_local_init, msg = self._model_ready_for_local_init(sess)
    484       if is_ready_for_local_init:
    485         logging.info("Running local_init_op.")
    486         sess.run(self._local_init_op)
    487         logging.info("Done running local_init_op.")
    488         return True, None
    489       else:
    490         return False, msg
    491     return True, None
    492 
    493 
    494 def _ready(op, sess, msg):
    495   """Checks if the model is ready or not, as determined by op.
    496 
    497   Args:
    498     op: An op, either _ready_op or _ready_for_local_init_op, which defines the
    499       readiness of the model.
    500     sess: A `Session`.
    501     msg: A message to log to warning if not ready
    502 
    503   Returns:
    504     A tuple (is_ready, msg), where is_ready is True if ready and False
    505     otherwise, and msg is `None` if the model is ready, a `String` with the
    506     reason why it is not ready otherwise.
    507   """
    508   if op is None:
    509     return True, None
    510   else:
    511     try:
    512       ready_value = sess.run(op)
    513       # The model is considered ready if ready_op returns an empty 1-D tensor.
    514       # Also compare to `None` and dtype being int32 for backward
    515       # compatibility.
    516       if (ready_value is None or ready_value.dtype == np.int32 or
    517           ready_value.size == 0):
    518         return True, None
    519       else:
    520         # TODO(sherrym): If a custom ready_op returns other types of tensor,
    521         # or strings other than variable names, this message could be
    522         # confusing.
    523         non_initialized_varnames = ", ".join(
    524             [i.decode("utf-8") for i in ready_value])
    525         return False, "Variables not initialized: " + non_initialized_varnames
    526     except errors.FailedPreconditionError as e:
    527       if "uninitialized" not in str(e):
    528         logging.warning("%s : error [%s]", msg, str(e))
    529         raise e
    530       return False, str(e)
    531 
    532 
    533 class _CountDownTimer(object):
    534 
    535   def __init__(self, duration_secs):
    536     self._start_time_secs = time.time()
    537     self._duration_secs = duration_secs
    538 
    539   def secs_remaining(self):
    540     diff = self._duration_secs - (time.time() - self._start_time_secs)
    541     return max(0, diff)
    542