Home | History | Annotate | Download | only in training
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Create threads to run multiple enqueue ops."""
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import threading
     22 import weakref
     23 
     24 from tensorflow.core.protobuf import queue_runner_pb2
     25 from tensorflow.python.client import session
     26 from tensorflow.python.eager import context
     27 from tensorflow.python.framework import errors
     28 from tensorflow.python.framework import ops
     29 from tensorflow.python.platform import tf_logging as logging
     30 from tensorflow.python.util.tf_export import tf_export
     31 
     32 
     33 @tf_export("train.queue_runner.QueueRunner", "train.QueueRunner")
     34 class QueueRunner(object):
     35   """Holds a list of enqueue operations for a queue, each to be run in a thread.
     36 
     37   Queues are a convenient TensorFlow mechanism to compute tensors
     38   asynchronously using multiple threads. For example in the canonical 'Input
     39   Reader' setup one set of threads generates filenames in a queue; a second set
     40   of threads read records from the files, processes them, and enqueues tensors
     41   on a second queue; a third set of threads dequeues these input records to
     42   construct batches and runs them through training operations.
     43 
     44   There are several delicate issues when running multiple threads that way:
     45   closing the queues in sequence as the input is exhausted, correctly catching
     46   and reporting exceptions, etc.
     47 
     48   The `QueueRunner`, combined with the `Coordinator`, helps handle these issues.
     49 
     50   @compatibility(eager)
     51   QueueRunners are not compatible with eager execution. Instead, please
     52   use `tf.data` to get data into your model.
     53   @end_compatibility
     54   """
     55 
     56   def __init__(self, queue=None, enqueue_ops=None, close_op=None,
     57                cancel_op=None, queue_closed_exception_types=None,
     58                queue_runner_def=None, import_scope=None):
     59     """Create a QueueRunner.
     60 
     61     On construction the `QueueRunner` adds an op to close the queue.  That op
     62     will be run if the enqueue ops raise exceptions.
     63 
     64     When you later call the `create_threads()` method, the `QueueRunner` will
     65     create one thread for each op in `enqueue_ops`.  Each thread will run its
     66     enqueue op in parallel with the other threads.  The enqueue ops do not have
     67     to all be the same op, but it is expected that they all enqueue tensors in
     68     `queue`.
     69 
     70     Args:
     71       queue: A `Queue`.
     72       enqueue_ops: List of enqueue ops to run in threads later.
     73       close_op: Op to close the queue. Pending enqueue ops are preserved.
     74       cancel_op: Op to close the queue and cancel pending enqueue ops.
     75       queue_closed_exception_types: Optional tuple of Exception types that
     76         indicate that the queue has been closed when raised during an enqueue
     77         operation.  Defaults to `(tf.errors.OutOfRangeError,)`.  Another common
     78         case includes `(tf.errors.OutOfRangeError, tf.errors.CancelledError)`,
     79         when some of the enqueue ops may dequeue from other Queues.
     80       queue_runner_def: Optional `QueueRunnerDef` protocol buffer. If specified,
     81         recreates the QueueRunner from its contents. `queue_runner_def` and the
     82         other arguments are mutually exclusive.
     83       import_scope: Optional `string`. Name scope to add. Only used when
     84         initializing from protocol buffer.
     85 
     86     Raises:
     87       ValueError: If both `queue_runner_def` and `queue` are both specified.
     88       ValueError: If `queue` or `enqueue_ops` are not provided when not
     89         restoring from `queue_runner_def`.
     90       RuntimeError: If eager execution is enabled.
     91     """
     92     if context.in_eager_mode():
     93       raise RuntimeError(
     94           "QueueRunners are not supported when eager execution is enabled. "
     95           "Instead, please use tf.data to get data into your model.")
     96 
     97     if queue_runner_def:
     98       if queue or enqueue_ops:
     99         raise ValueError("queue_runner_def and queue are mutually exclusive.")
    100       self._init_from_proto(queue_runner_def,
    101                             import_scope=import_scope)
    102     else:
    103       self._init_from_args(
    104           queue=queue, enqueue_ops=enqueue_ops,
    105           close_op=close_op, cancel_op=cancel_op,
    106           queue_closed_exception_types=queue_closed_exception_types)
    107     # Protect the count of runs to wait for.
    108     self._lock = threading.Lock()
    109     # A map from a session object to the number of outstanding queue runner
    110     # threads for that session.
    111     self._runs_per_session = weakref.WeakKeyDictionary()
    112     # List of exceptions raised by the running threads.
    113     self._exceptions_raised = []
    114 
    115   def _init_from_args(self, queue=None, enqueue_ops=None, close_op=None,
    116                       cancel_op=None, queue_closed_exception_types=None):
    117     """Create a QueueRunner from arguments.
    118 
    119     Args:
    120       queue: A `Queue`.
    121       enqueue_ops: List of enqueue ops to run in threads later.
    122       close_op: Op to close the queue. Pending enqueue ops are preserved.
    123       cancel_op: Op to close the queue and cancel pending enqueue ops.
    124       queue_closed_exception_types: Tuple of exception types, which indicate
    125         the queue has been safely closed.
    126 
    127     Raises:
    128       ValueError: If `queue` or `enqueue_ops` are not provided when not
    129         restoring from `queue_runner_def`.
    130       TypeError: If `queue_closed_exception_types` is provided, but is not
    131         a non-empty tuple of error types (subclasses of `tf.errors.OpError`).
    132     """
    133     if not queue or not enqueue_ops:
    134       raise ValueError("Must provide queue and enqueue_ops.")
    135     self._queue = queue
    136     self._enqueue_ops = enqueue_ops
    137     self._close_op = close_op
    138     self._cancel_op = cancel_op
    139     if queue_closed_exception_types is not None:
    140       if (not isinstance(queue_closed_exception_types, tuple)
    141           or not queue_closed_exception_types
    142           or not all(issubclass(t, errors.OpError)
    143                      for t in queue_closed_exception_types)):
    144         raise TypeError(
    145             "queue_closed_exception_types, when provided, "
    146             "must be a tuple of tf.error types, but saw: %s"
    147             % queue_closed_exception_types)
    148     self._queue_closed_exception_types = queue_closed_exception_types
    149     # Close when no more will be produced, but pending enqueues should be
    150     # preserved.
    151     if self._close_op is None:
    152       self._close_op = self._queue.close()
    153     # Close and cancel pending enqueues since there was an error and we want
    154     # to unblock everything so we can cleanly exit.
    155     if self._cancel_op is None:
    156       self._cancel_op = self._queue.close(cancel_pending_enqueues=True)
    157     if not self._queue_closed_exception_types:
    158       self._queue_closed_exception_types = (errors.OutOfRangeError,)
    159     else:
    160       self._queue_closed_exception_types = tuple(
    161           self._queue_closed_exception_types)
    162 
    163   def _init_from_proto(self, queue_runner_def, import_scope=None):
    164     """Create a QueueRunner from `QueueRunnerDef`.
    165 
    166     Args:
    167       queue_runner_def: Optional `QueueRunnerDef` protocol buffer.
    168       import_scope: Optional `string`. Name scope to add.
    169     """
    170     assert isinstance(queue_runner_def, queue_runner_pb2.QueueRunnerDef)
    171     g = ops.get_default_graph()
    172     self._queue = g.as_graph_element(
    173         ops.prepend_name_scope(queue_runner_def.queue_name, import_scope))
    174     self._enqueue_ops = [g.as_graph_element(
    175         ops.prepend_name_scope(op, import_scope))
    176                          for op in queue_runner_def.enqueue_op_name]
    177     self._close_op = g.as_graph_element(ops.prepend_name_scope(
    178         queue_runner_def.close_op_name, import_scope))
    179     self._cancel_op = g.as_graph_element(ops.prepend_name_scope(
    180         queue_runner_def.cancel_op_name, import_scope))
    181     self._queue_closed_exception_types = tuple(
    182         errors.exception_type_from_error_code(code)
    183         for code in queue_runner_def.queue_closed_exception_types)
    184     # Legacy support for old QueueRunnerDefs created before this field
    185     # was added.
    186     if not self._queue_closed_exception_types:
    187       self._queue_closed_exception_types = (errors.OutOfRangeError,)
    188 
    189   @property
    190   def queue(self):
    191     return self._queue
    192 
    193   @property
    194   def enqueue_ops(self):
    195     return self._enqueue_ops
    196 
    197   @property
    198   def close_op(self):
    199     return self._close_op
    200 
    201   @property
    202   def cancel_op(self):
    203     return self._cancel_op
    204 
    205   @property
    206   def queue_closed_exception_types(self):
    207     return self._queue_closed_exception_types
    208 
    209   @property
    210   def exceptions_raised(self):
    211     """Exceptions raised but not handled by the `QueueRunner` threads.
    212 
    213     Exceptions raised in queue runner threads are handled in one of two ways
    214     depending on whether or not a `Coordinator` was passed to
    215     `create_threads()`:
    216 
    217     * With a `Coordinator`, exceptions are reported to the coordinator and
    218       forgotten by the `QueueRunner`.
    219     * Without a `Coordinator`, exceptions are captured by the `QueueRunner` and
    220       made available in this `exceptions_raised` property.
    221 
    222     Returns:
    223       A list of Python `Exception` objects.  The list is empty if no exception
    224       was captured.  (No exceptions are captured when using a Coordinator.)
    225     """
    226     return self._exceptions_raised
    227 
    228   @property
    229   def name(self):
    230     """The string name of the underlying Queue."""
    231     return self._queue.name
    232 
    233   # pylint: disable=broad-except
    234   def _run(self, sess, enqueue_op, coord=None):
    235     """Execute the enqueue op in a loop, close the queue in case of error.
    236 
    237     Args:
    238       sess: A Session.
    239       enqueue_op: The Operation to run.
    240       coord: Optional Coordinator object for reporting errors and checking
    241         for stop conditions.
    242     """
    243     decremented = False
    244     try:
    245       # Make a cached callable from the `enqueue_op` to decrease the
    246       # Python overhead in the queue-runner loop.
    247       enqueue_callable = sess.make_callable(enqueue_op)
    248       while True:
    249         if coord and coord.should_stop():
    250           break
    251         try:
    252           enqueue_callable()
    253         except self._queue_closed_exception_types:  # pylint: disable=catching-non-exception
    254           # This exception indicates that a queue was closed.
    255           with self._lock:
    256             self._runs_per_session[sess] -= 1
    257             decremented = True
    258             if self._runs_per_session[sess] == 0:
    259               try:
    260                 sess.run(self._close_op)
    261               except Exception as e:
    262                 # Intentionally ignore errors from close_op.
    263                 logging.vlog(1, "Ignored exception: %s", str(e))
    264             return
    265     except Exception as e:
    266       # This catches all other exceptions.
    267       if coord:
    268         coord.request_stop(e)
    269       else:
    270         logging.error("Exception in QueueRunner: %s", str(e))
    271         with self._lock:
    272           self._exceptions_raised.append(e)
    273         raise
    274     finally:
    275       # Make sure we account for all terminations: normal or errors.
    276       if not decremented:
    277         with self._lock:
    278           self._runs_per_session[sess] -= 1
    279 
    280   def _close_on_stop(self, sess, cancel_op, coord):
    281     """Close the queue when the Coordinator requests stop.
    282 
    283     Args:
    284       sess: A Session.
    285       cancel_op: The Operation to run.
    286       coord: Coordinator.
    287     """
    288     coord.wait_for_stop()
    289     try:
    290       sess.run(cancel_op)
    291     except Exception as e:
    292       # Intentionally ignore errors from cancel_op.
    293       logging.vlog(1, "Ignored exception: %s", str(e))
    294   # pylint: enable=broad-except
    295 
    296   def create_threads(self, sess, coord=None, daemon=False, start=False):
    297     """Create threads to run the enqueue ops for the given session.
    298 
    299     This method requires a session in which the graph was launched.  It creates
    300     a list of threads, optionally starting them.  There is one thread for each
    301     op passed in `enqueue_ops`.
    302 
    303     The `coord` argument is an optional coordinator that the threads will use
    304     to terminate together and report exceptions.  If a coordinator is given,
    305     this method starts an additional thread to close the queue when the
    306     coordinator requests a stop.
    307 
    308     If previously created threads for the given session are still running, no
    309     new threads will be created.
    310 
    311     Args:
    312       sess: A `Session`.
    313       coord: Optional `Coordinator` object for reporting errors and checking
    314         stop conditions.
    315       daemon: Boolean.  If `True` make the threads daemon threads.
    316       start: Boolean.  If `True` starts the threads.  If `False` the
    317         caller must call the `start()` method of the returned threads.
    318 
    319     Returns:
    320       A list of threads.
    321     """
    322     with self._lock:
    323       try:
    324         if self._runs_per_session[sess] > 0:
    325           # Already started: no new threads to return.
    326           return []
    327       except KeyError:
    328         # We haven't seen this session yet.
    329         pass
    330       self._runs_per_session[sess] = len(self._enqueue_ops)
    331       self._exceptions_raised = []
    332 
    333     ret_threads = []
    334     for op in self._enqueue_ops:
    335       name = "QueueRunnerThread-{}-{}".format(self.name, op.name)
    336       ret_threads.append(threading.Thread(target=self._run,
    337                                           args=(sess, op, coord),
    338                                           name=name))
    339     if coord:
    340       name = "QueueRunnerThread-{}-close_on_stop".format(self.name)
    341       ret_threads.append(threading.Thread(target=self._close_on_stop,
    342                                           args=(sess, self._cancel_op, coord),
    343                                           name=name))
    344     for t in ret_threads:
    345       if coord:
    346         coord.register_thread(t)
    347       if daemon:
    348         t.daemon = True
    349       if start:
    350         t.start()
    351     return ret_threads
    352 
    353   def to_proto(self, export_scope=None):
    354     """Converts this `QueueRunner` to a `QueueRunnerDef` protocol buffer.
    355 
    356     Args:
    357       export_scope: Optional `string`. Name scope to remove.
    358 
    359     Returns:
    360       A `QueueRunnerDef` protocol buffer, or `None` if the `Variable` is not in
    361       the specified name scope.
    362     """
    363     if (export_scope is None or
    364         self.queue.name.startswith(export_scope)):
    365       queue_runner_def = queue_runner_pb2.QueueRunnerDef()
    366       queue_runner_def.queue_name = ops.strip_name_scope(
    367           self.queue.name, export_scope)
    368       for enqueue_op in self.enqueue_ops:
    369         queue_runner_def.enqueue_op_name.append(
    370             ops.strip_name_scope(enqueue_op.name, export_scope))
    371       queue_runner_def.close_op_name = ops.strip_name_scope(
    372           self.close_op.name, export_scope)
    373       queue_runner_def.cancel_op_name = ops.strip_name_scope(
    374           self.cancel_op.name, export_scope)
    375       queue_runner_def.queue_closed_exception_types.extend([
    376           errors.error_code_from_exception_type(cls)
    377           for cls in self._queue_closed_exception_types])
    378       return queue_runner_def
    379     else:
    380       return None
    381 
    382   @staticmethod
    383   def from_proto(queue_runner_def, import_scope=None):
    384     """Returns a `QueueRunner` object created from `queue_runner_def`."""
    385     return QueueRunner(queue_runner_def=queue_runner_def,
    386                        import_scope=import_scope)
    387 
    388 
    389 @tf_export("train.queue_runner.add_queue_runner", "train.add_queue_runner")
    390 def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS):
    391   """Adds a `QueueRunner` to a collection in the graph.
    392 
    393   When building a complex model that uses many queues it is often difficult to
    394   gather all the queue runners that need to be run.  This convenience function
    395   allows you to add a queue runner to a well known collection in the graph.
    396 
    397   The companion method `start_queue_runners()` can be used to start threads for
    398   all the collected queue runners.
    399 
    400   Args:
    401     qr: A `QueueRunner`.
    402     collection: A `GraphKey` specifying the graph collection to add
    403       the queue runner to.  Defaults to `GraphKeys.QUEUE_RUNNERS`.
    404   """
    405   ops.add_to_collection(collection, qr)
    406 
    407 
    408 @tf_export("train.queue_runner.start_queue_runners",
    409            "train.start_queue_runners")
    410 def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
    411                         collection=ops.GraphKeys.QUEUE_RUNNERS):
    412   """Starts all queue runners collected in the graph.
    413 
    414   This is a companion method to `add_queue_runner()`.  It just starts
    415   threads for all queue runners collected in the graph.  It returns
    416   the list of all threads.
    417 
    418   Args:
    419     sess: `Session` used to run the queue ops.  Defaults to the
    420       default session.
    421     coord: Optional `Coordinator` for coordinating the started threads.
    422     daemon: Whether the threads should be marked as `daemons`, meaning
    423       they don't block program exit.
    424     start: Set to `False` to only create the threads, not start them.
    425     collection: A `GraphKey` specifying the graph collection to
    426       get the queue runners from.  Defaults to `GraphKeys.QUEUE_RUNNERS`.
    427 
    428   Raises:
    429     ValueError: if `sess` is None and there isn't any default session.
    430     TypeError: if `sess` is not a `tf.Session` object.
    431 
    432   Returns:
    433     A list of threads.
    434 
    435   Raises:
    436     RuntimeError: If called with eager execution enabled.
    437     ValueError: If called without a default `tf.Session` registered.
    438 
    439   @compatibility(eager)
    440   Not compatible with eager execution. To ingest data under eager execution,
    441   use the `tf.data` API instead.
    442   @end_compatibility
    443   """
    444   if context.in_eager_mode():
    445     raise RuntimeError("Queues are not compatible with eager execution.")
    446   if sess is None:
    447     sess = ops.get_default_session()
    448     if not sess:
    449       raise ValueError("Cannot start queue runners: No default session is "
    450                        "registered. Use `with sess.as_default()` or pass an "
    451                        "explicit session to tf.start_queue_runners(sess=sess)")
    452 
    453   if not isinstance(sess, session.SessionInterface):
    454     # Following check is due to backward compatibility. (b/62061352)
    455     if sess.__class__.__name__ in [
    456         "MonitoredSession", "SingularMonitoredSession"]:
    457       return []
    458     raise TypeError("sess must be a `tf.Session` object. "
    459                     "Given class: {}".format(sess.__class__))
    460 
    461   with sess.graph.as_default():
    462     threads = []
    463     for qr in ops.get_collection(collection):
    464       threads.extend(qr.create_threads(sess, coord=coord, daemon=daemon,
    465                                        start=start))
    466   return threads
    467 
    468 
    469 ops.register_proto_function(ops.GraphKeys.QUEUE_RUNNERS,
    470                             proto_type=queue_runner_pb2.QueueRunnerDef,
    471                             to_proto=QueueRunner.to_proto,
    472                             from_proto=QueueRunner.from_proto)
    473