Home | History | Annotate | Download | only in training
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Coordinator to help multiple threads stop when requested."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import contextlib
     21 import sys
     22 import threading
     23 import time
     24 
     25 import six
     26 
     27 from tensorflow.python.framework import errors
     28 from tensorflow.python.platform import tf_logging as logging
     29 from tensorflow.python.util import compat
     30 from tensorflow.python.util.tf_export import tf_export
     31 
     32 
     33 @tf_export("train.Coordinator")
     34 class Coordinator(object):
     35   """A coordinator for threads.
     36 
     37   This class implements a simple mechanism to coordinate the termination of a
     38   set of threads.
     39 
     40   #### Usage:
     41 
     42   ```python
     43   # Create a coordinator.
     44   coord = Coordinator()
     45   # Start a number of threads, passing the coordinator to each of them.
     46   ...start thread 1...(coord, ...)
     47   ...start thread N...(coord, ...)
     48   # Wait for all the threads to terminate.
     49   coord.join(threads)
     50   ```
     51 
     52   Any of the threads can call `coord.request_stop()` to ask for all the threads
     53   to stop.  To cooperate with the requests, each thread must check for
     54   `coord.should_stop()` on a regular basis.  `coord.should_stop()` returns
     55   `True` as soon as `coord.request_stop()` has been called.
     56 
     57   A typical thread running with a coordinator will do something like:
     58 
     59   ```python
     60   while not coord.should_stop():
     61     ...do some work...
     62   ```
     63 
     64   #### Exception handling:
     65 
     66   A thread can report an exception to the coordinator as part of the
     67   `request_stop()` call.  The exception will be re-raised from the
     68   `coord.join()` call.
     69 
     70   Thread code:
     71 
     72   ```python
     73   try:
     74     while not coord.should_stop():
     75       ...do some work...
     76   except Exception as e:
     77     coord.request_stop(e)
     78   ```
     79 
     80   Main code:
     81 
     82   ```python
     83   try:
     84     ...
     85     coord = Coordinator()
     86     # Start a number of threads, passing the coordinator to each of them.
     87     ...start thread 1...(coord, ...)
     88     ...start thread N...(coord, ...)
     89     # Wait for all the threads to terminate.
     90     coord.join(threads)
     91   except Exception as e:
     92     ...exception that was passed to coord.request_stop()
     93   ```
     94 
     95   To simplify the thread implementation, the Coordinator provides a
     96   context handler `stop_on_exception()` that automatically requests a stop if
     97   an exception is raised.  Using the context handler the thread code above
     98   can be written as:
     99 
    100   ```python
    101   with coord.stop_on_exception():
    102     while not coord.should_stop():
    103       ...do some work...
    104   ```
    105 
    106   #### Grace period for stopping:
    107 
    108   After a thread has called `coord.request_stop()` the other threads have a
    109   fixed time to stop, this is called the 'stop grace period' and defaults to 2
    110   minutes.  If any of the threads is still alive after the grace period expires
    111   `coord.join()` raises a RuntimeError reporting the laggards.
    112 
    113   ```python
    114   try:
    115     ...
    116     coord = Coordinator()
    117     # Start a number of threads, passing the coordinator to each of them.
    118     ...start thread 1...(coord, ...)
    119     ...start thread N...(coord, ...)
    120     # Wait for all the threads to terminate, give them 10s grace period
    121     coord.join(threads, stop_grace_period_secs=10)
    122   except RuntimeError:
    123     ...one of the threads took more than 10s to stop after request_stop()
    124     ...was called.
    125   except Exception:
    126     ...exception that was passed to coord.request_stop()
    127   ```
    128   """
    129 
    130   def __init__(self, clean_stop_exception_types=None):
    131     """Create a new Coordinator.
    132 
    133     Args:
    134       clean_stop_exception_types: Optional tuple of Exception types that should
    135         cause a clean stop of the coordinator. If an exception of one of these
    136         types is reported to `request_stop(ex)` the coordinator will behave as
    137         if `request_stop(None)` was called.  Defaults to
    138         `(tf.errors.OutOfRangeError,)` which is used by input queues to signal
    139         the end of input. When feeding training data from a Python iterator it
    140         is common to add `StopIteration` to this list.
    141     """
    142     if clean_stop_exception_types is None:
    143       clean_stop_exception_types = (errors.OutOfRangeError,)
    144     self._clean_stop_exception_types = tuple(clean_stop_exception_types)
    145     # Protects all attributes.
    146     self._lock = threading.Lock()
    147     # Event set when threads must stop.
    148     self._stop_event = threading.Event()
    149     # Python exc_info to report.
    150     # If not None, it should hold the returned value of sys.exc_info(), which is
    151     # a tuple containing exception (type, value, traceback).
    152     self._exc_info_to_raise = None
    153     # True if we have called join() already.
    154     self._joined = False
    155     # Set of threads registered for joining when join() is called.  These
    156     # threads will be joined in addition to the threads passed to the join()
    157     # call.  It's ok if threads are both registered and passed to the join()
    158     # call.
    159     self._registered_threads = set()
    160 
    161   def _filter_exception(self, ex):
    162     """Check if the exception indicated in 'ex' should be ignored.
    163 
    164     This method examines `ex` to check if it is an exception that should be
    165     reported to the users.  If yes, it returns `ex` as is, otherwise it returns
    166     None.
    167 
    168     The code returns None for exception types listed in
    169     `_clean_stop_exception_types`.
    170 
    171     Args:
    172       ex: None, an `Exception`, or a Python `exc_info` tuple as returned by
    173         `sys.exc_info()`.
    174 
    175     Returns:
    176       ex or None.
    177     """
    178     if isinstance(ex, tuple):
    179       ex2 = ex[1]
    180     else:
    181       ex2 = ex
    182     if isinstance(ex2, self._clean_stop_exception_types):
    183       # Ignore the exception.
    184       ex = None
    185     return ex
    186 
    187   def request_stop(self, ex=None):
    188     """Request that the threads stop.
    189 
    190     After this is called, calls to `should_stop()` will return `True`.
    191 
    192     Note: If an exception is being passed in, in must be in the context of
    193     handling the exception (i.e. `try: ... except Exception as ex: ...`) and not
    194     a newly created one.
    195 
    196     Args:
    197       ex: Optional `Exception`, or Python `exc_info` tuple as returned by
    198         `sys.exc_info()`.  If this is the first call to `request_stop()` the
    199         corresponding exception is recorded and re-raised from `join()`.
    200     """
    201     with self._lock:
    202       ex = self._filter_exception(ex)
    203       # If we have already joined the coordinator the exception will not have a
    204       # chance to be reported, so just raise it normally.  This can happen if
    205       # you continue to use a session have having stopped and joined the
    206       # coordinator threads.
    207       if self._joined:
    208         if isinstance(ex, tuple):
    209           six.reraise(*ex)
    210         elif ex is not None:
    211           # NOTE(touts): This is bogus if request_stop() is not called
    212           # from the exception handler that raised ex.
    213           six.reraise(*sys.exc_info())
    214       if not self._stop_event.is_set():
    215         if ex and self._exc_info_to_raise is None:
    216           if isinstance(ex, tuple):
    217             logging.info("Error reported to Coordinator: %s",
    218                          compat.as_str_any(ex[1]),
    219                          exc_info=ex)
    220             self._exc_info_to_raise = ex
    221           else:
    222             logging.info("Error reported to Coordinator: %s, %s",
    223                          type(ex),
    224                          compat.as_str_any(ex))
    225             self._exc_info_to_raise = sys.exc_info()
    226           # self._exc_info_to_raise should contain a tuple containing exception
    227           # (type, value, traceback)
    228           if (len(self._exc_info_to_raise) != 3 or
    229               not self._exc_info_to_raise[0] or
    230               not self._exc_info_to_raise[1]):
    231             # Raise, catch and record the exception here so that error happens
    232             # where expected.
    233             try:
    234               raise ValueError(
    235                   "ex must be a tuple or sys.exc_info must return the current "
    236                   "exception: %s"
    237                   % self._exc_info_to_raise)
    238             except ValueError:
    239               # Record this error so it kills the coordinator properly.
    240               # NOTE(touts): As above, this is bogus if request_stop() is not
    241               # called from the exception handler that raised ex.
    242               self._exc_info_to_raise = sys.exc_info()
    243 
    244         self._stop_event.set()
    245 
    246   def clear_stop(self):
    247     """Clears the stop flag.
    248 
    249     After this is called, calls to `should_stop()` will return `False`.
    250     """
    251     with self._lock:
    252       self._joined = False
    253       self._exc_info_to_raise = None
    254       if self._stop_event.is_set():
    255         self._stop_event.clear()
    256 
    257   def should_stop(self):
    258     """Check if stop was requested.
    259 
    260     Returns:
    261       True if a stop was requested.
    262     """
    263     return self._stop_event.is_set()
    264 
    265   @contextlib.contextmanager
    266   def stop_on_exception(self):
    267     """Context manager to request stop when an Exception is raised.
    268 
    269     Code that uses a coordinator must catch exceptions and pass
    270     them to the `request_stop()` method to stop the other threads
    271     managed by the coordinator.
    272 
    273     This context handler simplifies the exception handling.
    274     Use it as follows:
    275 
    276     ```python
    277     with coord.stop_on_exception():
    278       # Any exception raised in the body of the with
    279       # clause is reported to the coordinator before terminating
    280       # the execution of the body.
    281       ...body...
    282     ```
    283 
    284     This is completely equivalent to the slightly longer code:
    285 
    286     ```python
    287     try:
    288       ...body...
    289     except:
    290       coord.request_stop(sys.exc_info())
    291     ```
    292 
    293     Yields:
    294       nothing.
    295     """
    296     try:
    297       yield
    298     except:  # pylint: disable=bare-except
    299       self.request_stop(ex=sys.exc_info())
    300 
    301   def wait_for_stop(self, timeout=None):
    302     """Wait till the Coordinator is told to stop.
    303 
    304     Args:
    305       timeout: Float.  Sleep for up to that many seconds waiting for
    306         should_stop() to become True.
    307 
    308     Returns:
    309       True if the Coordinator is told stop, False if the timeout expired.
    310     """
    311     return self._stop_event.wait(timeout)
    312 
    313   def register_thread(self, thread):
    314     """Register a thread to join.
    315 
    316     Args:
    317       thread: A Python thread to join.
    318     """
    319     with self._lock:
    320       self._registered_threads.add(thread)
    321 
    322   def join(self, threads=None, stop_grace_period_secs=120,
    323            ignore_live_threads=False):
    324     """Wait for threads to terminate.
    325 
    326     This call blocks until a set of threads have terminated.  The set of thread
    327     is the union of the threads passed in the `threads` argument and the list
    328     of threads that registered with the coordinator by calling
    329     `Coordinator.register_thread()`.
    330 
    331     After the threads stop, if an `exc_info` was passed to `request_stop`, that
    332     exception is re-raised.
    333 
    334     Grace period handling: When `request_stop()` is called, threads are given
    335     'stop_grace_period_secs' seconds to terminate.  If any of them is still
    336     alive after that period expires, a `RuntimeError` is raised.  Note that if
    337     an `exc_info` was passed to `request_stop()` then it is raised instead of
    338     that `RuntimeError`.
    339 
    340     Args:
    341       threads: List of `threading.Threads`. The started threads to join in
    342         addition to the registered threads.
    343       stop_grace_period_secs: Number of seconds given to threads to stop after
    344         `request_stop()` has been called.
    345       ignore_live_threads: If `False`, raises an error if any of the threads are
    346         still alive after `stop_grace_period_secs`.
    347 
    348     Raises:
    349       RuntimeError: If any thread is still alive after `request_stop()`
    350         is called and the grace period expires.
    351     """
    352     # Threads registered after this call will not be joined.
    353     with self._lock:
    354       if threads is None:
    355         threads = self._registered_threads
    356       else:
    357         threads = self._registered_threads.union(set(threads))
    358       # Copy the set into a list to avoid race conditions where a new thread
    359       # is added while we are waiting.
    360       threads = list(threads)
    361 
    362     # Wait for all threads to stop or for request_stop() to be called.
    363     while any(t.is_alive() for t in threads) and not self.wait_for_stop(1.0):
    364       pass
    365 
    366     # If any thread is still alive, wait for the grace period to expire.
    367     # By the time this check is executed, threads may still be shutting down,
    368     # so we add a sleep of increasing duration to give them a chance to shut
    369     # down without losing too many cycles.
    370     # The sleep duration is limited to the remaining grace duration.
    371     stop_wait_secs = 0.001
    372     while any(t.is_alive() for t in threads) and stop_grace_period_secs >= 0.0:
    373       time.sleep(stop_wait_secs)
    374       stop_grace_period_secs -= stop_wait_secs
    375       stop_wait_secs = 2 * stop_wait_secs
    376       # Keep the waiting period within sane bounds.
    377       # The minimum value is to avoid decreasing stop_wait_secs to a value
    378       # that could cause stop_grace_period_secs to remain unchanged.
    379       stop_wait_secs = max(min(stop_wait_secs, stop_grace_period_secs), 0.001)
    380 
    381     # List the threads still alive after the grace period.
    382     stragglers = [t.name for t in threads if t.is_alive()]
    383 
    384     # Terminate with an exception if appropriate.
    385     with self._lock:
    386       self._joined = True
    387       self._registered_threads = set()
    388       if self._exc_info_to_raise:
    389         six.reraise(*self._exc_info_to_raise)
    390       elif stragglers:
    391         if ignore_live_threads:
    392           logging.info("Coordinator stopped with threads still running: %s",
    393                        " ".join(stragglers))
    394         else:
    395           raise RuntimeError(
    396               "Coordinator stopped with threads still running: %s" %
    397               " ".join(stragglers))
    398 
    399   @property
    400   def joined(self):
    401     return self._joined
    402 
    403   def raise_requested_exception(self):
    404     """If an exception has been passed to `request_stop`, this raises it."""
    405     with self._lock:
    406       if self._exc_info_to_raise:
    407         six.reraise(*self._exc_info_to_raise)
    408 
    409 
    410 # Threads for the standard services.
    411 @tf_export("train.LooperThread")
    412 class LooperThread(threading.Thread):
    413   """A thread that runs code repeatedly, optionally on a timer.
    414 
    415   This thread class is intended to be used with a `Coordinator`.  It repeatedly
    416   runs code specified either as `target` and `args` or by the `run_loop()`
    417   method.
    418 
    419   Before each run the thread checks if the coordinator has requested stop.  In
    420   that case the looper thread terminates immediately.
    421 
    422   If the code being run raises an exception, that exception is reported to the
    423   coordinator and the thread terminates.  The coordinator will then request all
    424   the other threads it coordinates to stop.
    425 
    426   You typically pass looper threads to the supervisor `Join()` method.
    427   """
    428 
    429   def __init__(self, coord, timer_interval_secs, target=None, args=None,
    430                kwargs=None):
    431     """Create a LooperThread.
    432 
    433     Args:
    434       coord: A Coordinator.
    435       timer_interval_secs: Time boundaries at which to call Run(), or None
    436         if it should be called back to back.
    437       target: Optional callable object that will be executed in the thread.
    438       args: Optional arguments to pass to `target` when calling it.
    439       kwargs: Optional keyword arguments to pass to `target` when calling it.
    440 
    441     Raises:
    442       ValueError: If one of the arguments is invalid.
    443     """
    444     if not isinstance(coord, Coordinator):
    445       raise ValueError("'coord' argument must be a Coordinator: %s" % coord)
    446     super(LooperThread, self).__init__()
    447     self.daemon = True
    448     self._coord = coord
    449     self._timer_interval_secs = timer_interval_secs
    450     self._target = target
    451     if self._target:
    452       self._args = args or ()
    453       self._kwargs = kwargs or {}
    454     elif args or kwargs:
    455       raise ValueError("'args' and 'kwargs' argument require that you also "
    456                        "pass 'target'")
    457     self._coord.register_thread(self)
    458 
    459   @staticmethod
    460   def loop(coord, timer_interval_secs, target, args=None, kwargs=None):
    461     """Start a LooperThread that calls a function periodically.
    462 
    463     If `timer_interval_secs` is None the thread calls `target(args)`
    464     repeatedly.  Otherwise `target(args)` is called every `timer_interval_secs`
    465     seconds.  The thread terminates when a stop of the coordinator is
    466     requested.
    467 
    468     Args:
    469       coord: A Coordinator.
    470       timer_interval_secs: Number. Time boundaries at which to call `target`.
    471       target: A callable object.
    472       args: Optional arguments to pass to `target` when calling it.
    473       kwargs: Optional keyword arguments to pass to `target` when calling it.
    474 
    475     Returns:
    476       The started thread.
    477     """
    478     looper = LooperThread(coord, timer_interval_secs, target=target, args=args,
    479                           kwargs=kwargs)
    480     looper.start()
    481     return looper
    482 
    483   def run(self):
    484     with self._coord.stop_on_exception():
    485       self.start_loop()
    486       if self._timer_interval_secs is None:
    487         # Call back-to-back.
    488         while not self._coord.should_stop():
    489           self.run_loop()
    490       else:
    491         # Next time at which to call run_loop(), starts as 'now'.
    492         next_timer_time = time.time()
    493         while not self._coord.wait_for_stop(next_timer_time - time.time()):
    494           next_timer_time += self._timer_interval_secs
    495           self.run_loop()
    496       self.stop_loop()
    497 
    498   def start_loop(self):
    499     """Called when the thread starts."""
    500     pass
    501 
    502   def stop_loop(self):
    503     """Called when the thread stops."""
    504     pass
    505 
    506   def run_loop(self):
    507     """Called at 'timer_interval_secs' boundaries."""
    508     if self._target:
    509       self._target(*self._args, **self._kwargs)
    510