Home | History | Annotate | Download | only in training
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Training helper that checkpoints models and creates session."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import time
     21 import numpy as np
     22 
     23 from tensorflow.python.client import session
     24 from tensorflow.python.distribute import distribution_strategy_context
     25 from tensorflow.python.framework import errors
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.platform import tf_logging as logging
     28 from tensorflow.python.training import checkpoint_management
     29 from tensorflow.python.util.tf_export import tf_export
     30 
     31 
     32 def _maybe_name(obj):
     33   """Returns object name if it has one, or a message otherwise.
     34 
     35   This is useful for names that apper in error messages.
     36   Args:
     37     obj: Object to get the name of.
     38   Returns:
     39     name, "None", or a "no name" message.
     40   """
     41   if obj is None:
     42     return "None"
     43   elif hasattr(obj, "name"):
     44     return obj.name
     45   else:
     46     return "<no name for %s>" % type(obj)
     47 
     48 
     49 @tf_export(v1=["train.SessionManager"])
     50 class SessionManager(object):
     51   """Training helper that restores from checkpoint and creates session.
     52 
     53   This class is a small wrapper that takes care of session creation and
     54   checkpoint recovery. It also provides functions that to facilitate
     55   coordination among multiple training threads or processes.
     56 
     57   * Checkpointing trained variables as the training progresses.
     58   * Initializing variables on startup, restoring them from the most recent
     59     checkpoint after a crash, or wait for checkpoints to become available.
     60 
     61   ### Usage:
     62 
     63   ```python
     64   with tf.Graph().as_default():
     65      ...add operations to the graph...
     66     # Create a SessionManager that will checkpoint the model in '/tmp/mydir'.
     67     sm = SessionManager()
     68     sess = sm.prepare_session(master, init_op, saver, checkpoint_dir)
     69     # Use the session to train the graph.
     70     while True:
     71       sess.run(<my_train_op>)
     72   ```
     73 
     74   `prepare_session()` initializes or restores a model. It requires `init_op`
     75   and `saver` as an argument.
     76 
     77   A second process could wait for the model to be ready by doing the following:
     78 
     79   ```python
     80   with tf.Graph().as_default():
     81      ...add operations to the graph...
     82     # Create a SessionManager that will wait for the model to become ready.
     83     sm = SessionManager()
     84     sess = sm.wait_for_session(master)
     85     # Use the session to train the graph.
     86     while True:
     87       sess.run(<my_train_op>)
     88   ```
     89 
     90   `wait_for_session()` waits for a model to be initialized by other processes.
     91 
     92   """
     93 
     94   def __init__(self,
     95                local_init_op=None,
     96                ready_op=None,
     97                ready_for_local_init_op=None,
     98                graph=None,
     99                recovery_wait_secs=30,
    100                local_init_run_options=None):
    101     """Creates a SessionManager.
    102 
    103     The `local_init_op` is an `Operation` that is run always after a new session
    104     was created. If `None`, this step is skipped.
    105 
    106     The `ready_op` is an `Operation` used to check if the model is ready.  The
    107     model is considered ready if that operation returns an empty 1D string
    108     tensor. If the operation returns a non empty 1D string tensor, the elements
    109     are concatenated and used to indicate to the user why the model is not
    110     ready.
    111 
    112     The `ready_for_local_init_op` is an `Operation` used to check if the model
    113     is ready to run local_init_op.  The model is considered ready if that
    114     operation returns an empty 1D string tensor. If the operation returns a non
    115     empty 1D string tensor, the elements are concatenated and used to indicate
    116     to the user why the model is not ready.
    117 
    118     If `ready_op` is `None`, the model is not checked for readiness.
    119 
    120     `recovery_wait_secs` is the number of seconds between checks that
    121     the model is ready.  It is used by processes to wait for a model to
    122     be initialized or restored.  Defaults to 30 seconds.
    123 
    124     Args:
    125       local_init_op: An `Operation` run immediately after session creation.
    126          Usually used to initialize tables and local variables.
    127       ready_op: An `Operation` to check if the model is initialized.
    128       ready_for_local_init_op: An `Operation` to check if the model is ready
    129          to run local_init_op.
    130       graph: The `Graph` that the model will use.
    131       recovery_wait_secs: Seconds between checks for the model to be ready.
    132       local_init_run_options: RunOptions to be passed to session.run when
    133         executing the local_init_op.
    134 
    135     Raises:
    136       ValueError: If ready_for_local_init_op is not None but local_init_op is
    137         None
    138     """
    139     # Sets default values of arguments.
    140     if graph is None:
    141       graph = ops.get_default_graph()
    142     self._local_init_op = local_init_op
    143     self._ready_op = ready_op
    144     self._ready_for_local_init_op = ready_for_local_init_op
    145     self._graph = graph
    146     self._recovery_wait_secs = recovery_wait_secs
    147     self._target = None
    148     self._local_init_run_options = local_init_run_options
    149     if ready_for_local_init_op is not None and local_init_op is None:
    150       raise ValueError("If you pass a ready_for_local_init_op "
    151                        "you must also pass a local_init_op "
    152                        ", ready_for_local_init_op [%s]" %
    153                        ready_for_local_init_op)
    154 
    155   def _restore_checkpoint(self,
    156                           master,
    157                           saver=None,
    158                           checkpoint_dir=None,
    159                           checkpoint_filename_with_path=None,
    160                           wait_for_checkpoint=False,
    161                           max_wait_secs=7200,
    162                           config=None):
    163     """Creates a `Session`, and tries to restore a checkpoint.
    164 
    165 
    166     Args:
    167       master: `String` representation of the TensorFlow master to use.
    168       saver: A `Saver` object used to restore a model.
    169       checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
    170         dir will be used to restore.
    171       checkpoint_filename_with_path: Full file name path to the checkpoint file.
    172       wait_for_checkpoint: Whether to wait for checkpoint to become available.
    173       max_wait_secs: Maximum time to wait for checkpoints to become available.
    174       config: Optional `ConfigProto` proto used to configure the session.
    175 
    176     Returns:
    177       A pair (sess, is_restored) where 'is_restored' is `True` if
    178       the session could be restored, `False` otherwise.
    179 
    180     Raises:
    181       ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
    182         set.
    183     """
    184     self._target = master
    185 
    186     # This is required to so that we initialize the TPU device before
    187     # restoring from checkpoint since we'll be placing variables on the device
    188     # and TPUInitialize wipes out the memory of the device.
    189     strategy = distribution_strategy_context.get_strategy()
    190     if strategy and hasattr(strategy.extended,
    191                             "_experimental_initialize_system"):
    192       strategy.extended._experimental_initialize_system()  # pylint: disable=protected-access
    193 
    194     sess = session.Session(self._target, graph=self._graph, config=config)
    195     if checkpoint_dir and checkpoint_filename_with_path:
    196       raise ValueError("Can not provide both checkpoint_dir and "
    197                        "checkpoint_filename_with_path.")
    198     # If either saver or checkpoint_* is not specified, cannot restore. Just
    199     # return.
    200     if not saver or not (checkpoint_dir or checkpoint_filename_with_path):
    201       return sess, False
    202 
    203     if checkpoint_filename_with_path:
    204       saver.restore(sess, checkpoint_filename_with_path)
    205       return sess, True
    206 
    207     # Waits up until max_wait_secs for checkpoint to become available.
    208     wait_time = 0
    209     ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
    210     while not ckpt or not ckpt.model_checkpoint_path:
    211       if wait_for_checkpoint and wait_time < max_wait_secs:
    212         logging.info("Waiting for checkpoint to be available.")
    213         time.sleep(self._recovery_wait_secs)
    214         wait_time += self._recovery_wait_secs
    215         ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
    216       else:
    217         return sess, False
    218 
    219     # Loads the checkpoint.
    220     saver.restore(sess, ckpt.model_checkpoint_path)
    221     saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
    222     return sess, True
    223 
    224   def prepare_session(self,
    225                       master,
    226                       init_op=None,
    227                       saver=None,
    228                       checkpoint_dir=None,
    229                       checkpoint_filename_with_path=None,
    230                       wait_for_checkpoint=False,
    231                       max_wait_secs=7200,
    232                       config=None,
    233                       init_feed_dict=None,
    234                       init_fn=None):
    235     """Creates a `Session`. Makes sure the model is ready to be used.
    236 
    237     Creates a `Session` on 'master'. If a `saver` object is passed in, and
    238     `checkpoint_dir` points to a directory containing valid checkpoint
    239     files, then it will try to recover the model from checkpoint. If
    240     no checkpoint files are available, and `wait_for_checkpoint` is
    241     `True`, then the process would check every `recovery_wait_secs`,
    242     up to `max_wait_secs`, for recovery to succeed.
    243 
    244     If the model cannot be recovered successfully then it is initialized by
    245     running the `init_op` and calling `init_fn` if they are provided.
    246     The `local_init_op` is also run after init_op and init_fn, regardless of
    247     whether the model was recovered successfully, but only if
    248     `ready_for_local_init_op` passes.
    249 
    250     If the model is recovered from a checkpoint it is assumed that all
    251     global variables have been initialized, in particular neither `init_op`
    252     nor `init_fn` will be executed.
    253 
    254     It is an error if the model cannot be recovered and no `init_op`
    255     or `init_fn` or `local_init_op` are passed.
    256 
    257     Args:
    258       master: `String` representation of the TensorFlow master to use.
    259       init_op: Optional `Operation` used to initialize the model.
    260       saver: A `Saver` object used to restore a model.
    261       checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
    262         dir will be used to restore.
    263       checkpoint_filename_with_path: Full file name path to the checkpoint file.
    264       wait_for_checkpoint: Whether to wait for checkpoint to become available.
    265       max_wait_secs: Maximum time to wait for checkpoints to become available.
    266       config: Optional `ConfigProto` proto used to configure the session.
    267       init_feed_dict: Optional dictionary that maps `Tensor` objects to feed
    268         values.  This feed dictionary is passed to the session `run()` call when
    269         running the init op.
    270       init_fn: Optional callable used to initialize the model. Called after the
    271         optional `init_op` is called.  The callable must accept one argument,
    272         the session being initialized.
    273 
    274     Returns:
    275       A `Session` object that can be used to drive the model.
    276 
    277     Raises:
    278       RuntimeError: If the model cannot be initialized or recovered.
    279       ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
    280         set.
    281     """
    282 
    283     sess, is_loaded_from_checkpoint = self._restore_checkpoint(
    284         master,
    285         saver,
    286         checkpoint_dir=checkpoint_dir,
    287         checkpoint_filename_with_path=checkpoint_filename_with_path,
    288         wait_for_checkpoint=wait_for_checkpoint,
    289         max_wait_secs=max_wait_secs,
    290         config=config)
    291     if not is_loaded_from_checkpoint:
    292       if init_op is None and not init_fn and self._local_init_op is None:
    293         raise RuntimeError("Model is not initialized and no init_op or "
    294                            "init_fn or local_init_op was given")
    295       if init_op is not None:
    296         sess.run(init_op, feed_dict=init_feed_dict)
    297       if init_fn:
    298         init_fn(sess)
    299 
    300     local_init_success, msg = self._try_run_local_init_op(sess)
    301     if not local_init_success:
    302       raise RuntimeError(
    303           "Init operations did not make model ready for local_init.  "
    304           "Init op: %s, init fn: %s, error: %s" % (_maybe_name(init_op),
    305                                                    init_fn,
    306                                                    msg))
    307 
    308     is_ready, msg = self._model_ready(sess)
    309     if not is_ready:
    310       raise RuntimeError(
    311           "Init operations did not make model ready.  "
    312           "Init op: %s, init fn: %s, local_init_op: %s, error: %s" %
    313           (_maybe_name(init_op), init_fn, self._local_init_op, msg))
    314     return sess
    315 
    316   def recover_session(self,
    317                       master,
    318                       saver=None,
    319                       checkpoint_dir=None,
    320                       checkpoint_filename_with_path=None,
    321                       wait_for_checkpoint=False,
    322                       max_wait_secs=7200,
    323                       config=None):
    324     """Creates a `Session`, recovering if possible.
    325 
    326     Creates a new session on 'master'.  If the session is not initialized
    327     and can be recovered from a checkpoint, recover it.
    328 
    329     Args:
    330       master: `String` representation of the TensorFlow master to use.
    331       saver: A `Saver` object used to restore a model.
    332       checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
    333         dir will be used to restore.
    334       checkpoint_filename_with_path: Full file name path to the checkpoint file.
    335       wait_for_checkpoint: Whether to wait for checkpoint to become available.
    336       max_wait_secs: Maximum time to wait for checkpoints to become available.
    337       config: Optional `ConfigProto` proto used to configure the session.
    338 
    339     Returns:
    340       A pair (sess, initialized) where 'initialized' is `True` if
    341       the session could be recovered and initialized, `False` otherwise.
    342 
    343     Raises:
    344       ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
    345         set.
    346     """
    347 
    348     sess, is_loaded_from_checkpoint = self._restore_checkpoint(
    349         master,
    350         saver,
    351         checkpoint_dir=checkpoint_dir,
    352         checkpoint_filename_with_path=checkpoint_filename_with_path,
    353         wait_for_checkpoint=wait_for_checkpoint,
    354         max_wait_secs=max_wait_secs,
    355         config=config)
    356 
    357     # Always try to run local_init_op
    358     local_init_success, msg = self._try_run_local_init_op(sess)
    359 
    360     if not is_loaded_from_checkpoint:
    361       # Do not need to run checks for readiness
    362       return sess, False
    363 
    364     restoring_file = checkpoint_dir or checkpoint_filename_with_path
    365     if not local_init_success:
    366       logging.info(
    367           "Restoring model from %s did not make model ready for local init:"
    368           " %s", restoring_file, msg)
    369       return sess, False
    370 
    371     is_ready, msg = self._model_ready(sess)
    372     if not is_ready:
    373       logging.info("Restoring model from %s did not make model ready: %s",
    374                    restoring_file, msg)
    375       return sess, False
    376 
    377     logging.info("Restored model from %s", restoring_file)
    378     return sess, is_loaded_from_checkpoint
    379 
    380   def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")):
    381     """Creates a new `Session` and waits for model to be ready.
    382 
    383     Creates a new `Session` on 'master'.  Waits for the model to be
    384     initialized or recovered from a checkpoint.  It's expected that
    385     another thread or process will make the model ready, and that this
    386     is intended to be used by threads/processes that participate in a
    387     distributed training configuration where a different thread/process
    388     is responsible for initializing or recovering the model being trained.
    389 
    390     NB: The amount of time this method waits for the session is bounded
    391     by max_wait_secs. By default, this function will wait indefinitely.
    392 
    393     Args:
    394       master: `String` representation of the TensorFlow master to use.
    395       config: Optional ConfigProto proto used to configure the session.
    396       max_wait_secs: Maximum time to wait for the session to become available.
    397 
    398     Returns:
    399       A `Session`. May be None if the operation exceeds the timeout
    400       specified by config.operation_timeout_in_ms.
    401 
    402     Raises:
    403       tf.DeadlineExceededError: if the session is not available after
    404         max_wait_secs.
    405     """
    406     self._target = master
    407 
    408     if max_wait_secs is None:
    409       max_wait_secs = float("Inf")
    410     timer = _CountDownTimer(max_wait_secs)
    411 
    412     while True:
    413       sess = session.Session(self._target, graph=self._graph, config=config)
    414       not_ready_msg = None
    415       not_ready_local_msg = None
    416       local_init_success, not_ready_local_msg = self._try_run_local_init_op(
    417           sess)
    418       if local_init_success:
    419         # Successful if local_init_op is None, or ready_for_local_init_op passes
    420         is_ready, not_ready_msg = self._model_ready(sess)
    421         if is_ready:
    422           return sess
    423 
    424       self._safe_close(sess)
    425 
    426       # Do we have enough time left to try again?
    427       remaining_ms_after_wait = (
    428           timer.secs_remaining() - self._recovery_wait_secs)
    429       if remaining_ms_after_wait < 0:
    430         raise errors.DeadlineExceededError(
    431             None, None,
    432             "Session was not ready after waiting %d secs." % (max_wait_secs,))
    433 
    434       logging.info("Waiting for model to be ready.  "
    435                    "Ready_for_local_init_op:  %s, ready: %s",
    436                    not_ready_local_msg, not_ready_msg)
    437       time.sleep(self._recovery_wait_secs)
    438 
    439   def _safe_close(self, sess):
    440     """Closes a session without raising an exception.
    441 
    442     Just like sess.close() but ignores exceptions.
    443 
    444     Args:
    445       sess: A `Session`.
    446     """
    447     # pylint: disable=broad-except
    448     try:
    449       sess.close()
    450     except Exception:
    451       # Intentionally not logging to avoid user complaints that
    452       # they get cryptic errors.  We really do not care that Close
    453       # fails.
    454       pass
    455     # pylint: enable=broad-except
    456 
    457   def _model_ready(self, sess):
    458     """Checks if the model is ready or not.
    459 
    460     Args:
    461       sess: A `Session`.
    462 
    463     Returns:
    464       A tuple (is_ready, msg), where is_ready is True if ready and False
    465       otherwise, and msg is `None` if the model is ready, a `String` with the
    466       reason why it is not ready otherwise.
    467     """
    468     return _ready(self._ready_op, sess, "Model not ready")
    469 
    470   def _model_ready_for_local_init(self, sess):
    471     """Checks if the model is ready to run local_init_op.
    472 
    473     Args:
    474       sess: A `Session`.
    475 
    476     Returns:
    477       A tuple (is_ready, msg), where is_ready is True if ready to run
    478       local_init_op and False otherwise, and msg is `None` if the model is
    479       ready to run local_init_op, a `String` with the reason why it is not ready
    480       otherwise.
    481     """
    482     return _ready(self._ready_for_local_init_op, sess,
    483                   "Model not ready for local init")
    484 
    485   def _try_run_local_init_op(self, sess):
    486     """Tries to run _local_init_op, if not None, and is ready for local init.
    487 
    488     Args:
    489       sess: A `Session`.
    490 
    491     Returns:
    492       A tuple (is_successful, msg), where is_successful is True if
    493       _local_init_op is None, or we ran _local_init_op, and False otherwise;
    494       and msg is a `String` with the reason why the model was not ready to run
    495       local init.
    496     """
    497     if self._local_init_op is not None:
    498       is_ready_for_local_init, msg = self._model_ready_for_local_init(sess)
    499       if is_ready_for_local_init:
    500         logging.info("Running local_init_op.")
    501         sess.run(self._local_init_op, options=self._local_init_run_options)
    502         logging.info("Done running local_init_op.")
    503         return True, None
    504       else:
    505         return False, msg
    506     return True, None
    507 
    508 
    509 def _ready(op, sess, msg):
    510   """Checks if the model is ready or not, as determined by op.
    511 
    512   Args:
    513     op: An op, either _ready_op or _ready_for_local_init_op, which defines the
    514       readiness of the model.
    515     sess: A `Session`.
    516     msg: A message to log to warning if not ready
    517 
    518   Returns:
    519     A tuple (is_ready, msg), where is_ready is True if ready and False
    520     otherwise, and msg is `None` if the model is ready, a `String` with the
    521     reason why it is not ready otherwise.
    522   """
    523   if op is None:
    524     return True, None
    525   else:
    526     try:
    527       ready_value = sess.run(op)
    528       # The model is considered ready if ready_op returns an empty 1-D tensor.
    529       # Also compare to `None` and dtype being int32 for backward
    530       # compatibility.
    531       if (ready_value is None or ready_value.dtype == np.int32 or
    532           ready_value.size == 0):
    533         return True, None
    534       else:
    535         # TODO(sherrym): If a custom ready_op returns other types of tensor,
    536         # or strings other than variable names, this message could be
    537         # confusing.
    538         non_initialized_varnames = ", ".join(
    539             [i.decode("utf-8") for i in ready_value])
    540         return False, "Variables not initialized: " + non_initialized_varnames
    541     except errors.FailedPreconditionError as e:
    542       if "uninitialized" not in str(e):
    543         logging.warning("%s : error [%s]", msg, str(e))
    544         raise e
    545       return False, str(e)
    546 
    547 
    548 class _CountDownTimer(object):
    549 
    550   def __init__(self, duration_secs):
    551     self._start_time_secs = time.time()
    552     self._duration_secs = duration_secs
    553 
    554   def secs_remaining(self):
    555     diff = self._duration_secs - (time.time() - self._start_time_secs)
    556     return max(0, diff)
    557