# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Create threads to run multiple enqueue ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading
import weakref

from tensorflow.core.protobuf import queue_runner_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export


@tf_export("train.queue_runner.QueueRunner", "train.QueueRunner")
class QueueRunner(object):
  """Holds a list of enqueue operations for a queue, each to be run in a thread.

  Queues are a convenient TensorFlow mechanism to compute tensors
  asynchronously using multiple threads. For example in the canonical 'Input
  Reader' setup one set of threads generates filenames in a queue; a second set
  of threads read records from the files, processes them, and enqueues tensors
  on a second queue; a third set of threads dequeues these input records to
  construct batches and runs them through training operations.

  There are several delicate issues when running multiple threads that way:
  closing the queues in sequence as the input is exhausted, correctly catching
  and reporting exceptions, etc.

  The `QueueRunner`, combined with the `Coordinator`, helps handle these issues.

  @compatibility(eager)
  QueueRunners are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """

  def __init__(self, queue=None, enqueue_ops=None, close_op=None,
               cancel_op=None, queue_closed_exception_types=None,
               queue_runner_def=None, import_scope=None):
    """Create a QueueRunner.

    On construction the `QueueRunner` adds an op to close the queue.  That op
    will be run if the enqueue ops raise exceptions.

    When you later call the `create_threads()` method, the `QueueRunner` will
    create one thread for each op in `enqueue_ops`.  Each thread will run its
    enqueue op in parallel with the other threads.  The enqueue ops do not have
    to all be the same op, but it is expected that they all enqueue tensors in
    `queue`.

    Args:
      queue: A `Queue`.
      enqueue_ops: List of enqueue ops to run in threads later.
      close_op: Op to close the queue. Pending enqueue ops are preserved.
      cancel_op: Op to close the queue and cancel pending enqueue ops.
      queue_closed_exception_types: Optional tuple of Exception types that
        indicate that the queue has been closed when raised during an enqueue
        operation.  Defaults to `(tf.errors.OutOfRangeError,)`.  Another common
        case includes `(tf.errors.OutOfRangeError, tf.errors.CancelledError)`,
        when some of the enqueue ops may dequeue from other Queues.
      queue_runner_def: Optional `QueueRunnerDef` protocol buffer. If specified,
        recreates the QueueRunner from its contents. `queue_runner_def` and the
        other arguments are mutually exclusive.
      import_scope: Optional `string`. Name scope to add. Only used when
        initializing from protocol buffer.

    Raises:
      ValueError: If `queue_runner_def` is specified together with `queue` or
        `enqueue_ops`.
      ValueError: If `queue` or `enqueue_ops` are not provided when not
        restoring from `queue_runner_def`.
      TypeError: If `queue_closed_exception_types` is provided, but is not
        a non-empty tuple of error types (subclasses of `tf.errors.OpError`).
      RuntimeError: If eager execution is enabled.
    """
    if context.in_eager_mode():
      raise RuntimeError(
          "QueueRunners are not supported when eager execution is enabled. "
          "Instead, please use tf.data to get data into your model.")

    if queue_runner_def:
      if queue or enqueue_ops:
        raise ValueError("queue_runner_def and queue are mutually exclusive.")
      self._init_from_proto(queue_runner_def,
                            import_scope=import_scope)
    else:
      self._init_from_args(
          queue=queue, enqueue_ops=enqueue_ops,
          close_op=close_op, cancel_op=cancel_op,
          queue_closed_exception_types=queue_closed_exception_types)
    # Protect the count of runs to wait for.
    self._lock = threading.Lock()
    # A map from a session object to the number of outstanding queue runner
    # threads for that session.  Weak keys, so a discarded session does not
    # stay alive just because it was once used with this runner.
    self._runs_per_session = weakref.WeakKeyDictionary()
    # List of exceptions raised by the running threads.
    self._exceptions_raised = []

  def _init_from_args(self, queue=None, enqueue_ops=None, close_op=None,
                      cancel_op=None, queue_closed_exception_types=None):
    """Create a QueueRunner from arguments.

    Args:
      queue: A `Queue`.
      enqueue_ops: List of enqueue ops to run in threads later.
      close_op: Op to close the queue. Pending enqueue ops are preserved.
      cancel_op: Op to close the queue and cancel pending enqueue ops.
      queue_closed_exception_types: Tuple of exception types, which indicate
        the queue has been safely closed.

    Raises:
      ValueError: If `queue` or `enqueue_ops` are not provided when not
        restoring from `queue_runner_def`.
      TypeError: If `queue_closed_exception_types` is provided, but is not
        a non-empty tuple of error types (subclasses of `tf.errors.OpError`).
    """
    if not queue or not enqueue_ops:
      raise ValueError("Must provide queue and enqueue_ops.")
    self._queue = queue
    self._enqueue_ops = enqueue_ops
    self._close_op = close_op
    self._cancel_op = cancel_op
    if queue_closed_exception_types is not None:
      # `isinstance(t, type)` guards `issubclass`, which would otherwise raise
      # its own, unrelated TypeError for non-class entries.
      if (not isinstance(queue_closed_exception_types, tuple)
          or not queue_closed_exception_types
          or not all(isinstance(t, type) and issubclass(t, errors.OpError)
                     for t in queue_closed_exception_types)):
        # Wrap the operand in a 1-tuple: `"%s" % some_tuple` would treat the
        # tuple's elements as separate format arguments and fail to format.
        raise TypeError(
            "queue_closed_exception_types, when provided, "
            "must be a tuple of tf.error types, but saw: %s"
            % (queue_closed_exception_types,))
    self._queue_closed_exception_types = queue_closed_exception_types
    # Close when no more will be produced, but pending enqueues should be
    # preserved.
    if self._close_op is None:
      self._close_op = self._queue.close()
    # Close and cancel pending enqueues since there was an error and we want
    # to unblock everything so we can cleanly exit.
    if self._cancel_op is None:
      self._cancel_op = self._queue.close(cancel_pending_enqueues=True)
    if not self._queue_closed_exception_types:
      self._queue_closed_exception_types = (errors.OutOfRangeError,)
    else:
      self._queue_closed_exception_types = tuple(
          self._queue_closed_exception_types)

  def _init_from_proto(self, queue_runner_def, import_scope=None):
    """Create a QueueRunner from `QueueRunnerDef`.

    Op and queue names stored in the proto are resolved against the current
    default graph, after prepending `import_scope`.

    Args:
      queue_runner_def: Optional `QueueRunnerDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(queue_runner_def, queue_runner_pb2.QueueRunnerDef)
    g = ops.get_default_graph()
    self._queue = g.as_graph_element(
        ops.prepend_name_scope(queue_runner_def.queue_name, import_scope))
    self._enqueue_ops = [g.as_graph_element(
        ops.prepend_name_scope(op, import_scope))
                         for op in queue_runner_def.enqueue_op_name]
    self._close_op = g.as_graph_element(ops.prepend_name_scope(
        queue_runner_def.close_op_name, import_scope))
    self._cancel_op = g.as_graph_element(ops.prepend_name_scope(
        queue_runner_def.cancel_op_name, import_scope))
    # The proto stores error codes; map them back to exception classes.
    self._queue_closed_exception_types = tuple(
        errors.exception_type_from_error_code(code)
        for code in queue_runner_def.queue_closed_exception_types)
    # Legacy support for old QueueRunnerDefs created before this field
    # was added.
    if not self._queue_closed_exception_types:
      self._queue_closed_exception_types = (errors.OutOfRangeError,)

  @property
  def queue(self):
    """The `Queue` this runner feeds."""
    return self._queue

  @property
  def enqueue_ops(self):
    """The list of enqueue ops; one thread is created per op."""
    return self._enqueue_ops

  @property
  def close_op(self):
    """Op that closes the queue while preserving pending enqueues."""
    return self._close_op

  @property
  def cancel_op(self):
    """Op that closes the queue and cancels pending enqueues."""
    return self._cancel_op

  @property
  def queue_closed_exception_types(self):
    """Tuple of exception types treated as "the queue was closed"."""
    return self._queue_closed_exception_types

  @property
  def exceptions_raised(self):
    """Exceptions raised but not handled by the `QueueRunner` threads.

    Exceptions raised in queue runner threads are handled in one of two ways
    depending on whether or not a `Coordinator` was passed to
    `create_threads()`:

    * With a `Coordinator`, exceptions are reported to the coordinator and
      forgotten by the `QueueRunner`.
    * Without a `Coordinator`, exceptions are captured by the `QueueRunner` and
      made available in this `exceptions_raised` property.

    Returns:
      A list of Python `Exception` objects.  The list is empty if no exception
      was captured.  (No exceptions are captured when using a Coordinator.)
    """
    return self._exceptions_raised

  @property
  def name(self):
    """The string name of the underlying Queue."""
    return self._queue.name

  # pylint: disable=broad-except
  def _run(self, sess, enqueue_op, coord=None):
    """Execute the enqueue op in a loop, close the queue in case of error.

    Every exit path decrements this session's outstanding-thread count
    exactly once.  When a "queue closed" exception drives the count to zero,
    `close_op` is run.

    Args:
      sess: A Session.
      enqueue_op: The Operation to run.
      coord: Optional Coordinator object for reporting errors and checking
        for stop conditions.
    """
    decremented = False
    try:
      # Make a cached callable from the `enqueue_op` to decrease the
      # Python overhead in the queue-runner loop.
      enqueue_callable = sess.make_callable(enqueue_op)
      while True:
        if coord and coord.should_stop():
          break
        try:
          enqueue_callable()
        except self._queue_closed_exception_types:  # pylint: disable=catching-non-exception
          # This exception indicates that a queue was closed.
          with self._lock:
            self._runs_per_session[sess] -= 1
            decremented = True
            # Only the last thread of this session to finish closes the queue.
            if self._runs_per_session[sess] == 0:
              try:
                sess.run(self._close_op)
              except Exception as e:
                # Intentionally ignore errors from close_op.
                logging.vlog(1, "Ignored exception: %s", str(e))
            return
    except Exception as e:
      # This catches all other exceptions.
      if coord:
        coord.request_stop(e)
      else:
        logging.error("Exception in QueueRunner: %s", str(e))
        with self._lock:
          self._exceptions_raised.append(e)
        raise
    finally:
      # Make sure we account for all terminations: normal or errors.
      if not decremented:
        with self._lock:
          self._runs_per_session[sess] -= 1

  def _close_on_stop(self, sess, cancel_op, coord):
    """Close the queue when the Coordinator requests stop.

    Blocks until `coord` requests a stop, then runs `cancel_op`, ignoring any
    error it raises.

    Args:
      sess: A Session.
      cancel_op: The Operation to run.
      coord: Coordinator.
    """
    coord.wait_for_stop()
    try:
      sess.run(cancel_op)
    except Exception as e:
      # Intentionally ignore errors from cancel_op.
      logging.vlog(1, "Ignored exception: %s", str(e))
  # pylint: enable=broad-except

  def create_threads(self, sess, coord=None, daemon=False, start=False):
    """Create threads to run the enqueue ops for the given session.

    This method requires a session in which the graph was launched.  It creates
    a list of threads, optionally starting them.  There is one thread for each
    op passed in `enqueue_ops`.

    The `coord` argument is an optional coordinator that the threads will use
    to terminate together and report exceptions.  If a coordinator is given,
    this method starts an additional thread to close the queue when the
    coordinator requests a stop.

    If previously created threads for the given session are still running, no
    new threads will be created.

    Args:
      sess: A `Session`.
      coord: Optional `Coordinator` object for reporting errors and checking
        stop conditions.
      daemon: Boolean.  If `True` make the threads daemon threads.
      start: Boolean.  If `True` starts the threads.  If `False` the
        caller must call the `start()` method of the returned threads.

    Returns:
      A list of threads.
    """
    with self._lock:
      # A session that has never been seen counts as having no live threads.
      if self._runs_per_session.get(sess, 0) > 0:
        # Already started: no new threads to return.
        return []
      self._runs_per_session[sess] = len(self._enqueue_ops)
      self._exceptions_raised = []

    ret_threads = []
    for op in self._enqueue_ops:
      name = "QueueRunnerThread-{}-{}".format(self.name, op.name)
      ret_threads.append(threading.Thread(target=self._run,
                                          args=(sess, op, coord),
                                          name=name))
    if coord:
      name = "QueueRunnerThread-{}-close_on_stop".format(self.name)
      ret_threads.append(threading.Thread(target=self._close_on_stop,
                                          args=(sess, self._cancel_op, coord),
                                          name=name))
    for t in ret_threads:
      if coord:
        coord.register_thread(t)
      if daemon:
        t.daemon = True
      if start:
        t.start()
    return ret_threads

  def to_proto(self, export_scope=None):
    """Converts this `QueueRunner` to a `QueueRunnerDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `QueueRunnerDef` protocol buffer, or `None` if the `QueueRunner`'s
      queue is not in the specified name scope.
    """
    # NOTE(review): this is a plain string-prefix test, so a scope "foo" also
    # matches a queue named "foobar/..."; confirm that is intended.
    if (export_scope is None or
        self.queue.name.startswith(export_scope)):
      queue_runner_def = queue_runner_pb2.QueueRunnerDef()
      queue_runner_def.queue_name = ops.strip_name_scope(
          self.queue.name, export_scope)
      for enqueue_op in self.enqueue_ops:
        queue_runner_def.enqueue_op_name.append(
            ops.strip_name_scope(enqueue_op.name, export_scope))
      queue_runner_def.close_op_name = ops.strip_name_scope(
          self.close_op.name, export_scope)
      queue_runner_def.cancel_op_name = ops.strip_name_scope(
          self.cancel_op.name, export_scope)
      # Exception classes are serialized as their error codes.
      queue_runner_def.queue_closed_exception_types.extend([
          errors.error_code_from_exception_type(cls)
          for cls in self._queue_closed_exception_types])
      return queue_runner_def
    else:
      return None

  @staticmethod
  def from_proto(queue_runner_def, import_scope=None):
    """Returns a `QueueRunner` object created from `queue_runner_def`."""
    return QueueRunner(queue_runner_def=queue_runner_def,
                       import_scope=import_scope)


@tf_export("train.queue_runner.add_queue_runner", "train.add_queue_runner")
def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Adds a `QueueRunner` to a collection in the graph.

  When building a complex model that uses many queues it is often difficult to
  gather all the queue runners that need to be run.  This convenience function
  allows you to add a queue runner to a well known collection in the graph.

  The companion method `start_queue_runners()` can be used to start threads for
  all the collected queue runners.

  Args:
    qr: A `QueueRunner`.
    collection: A `GraphKey` specifying the graph collection to add
      the queue runner to.  Defaults to `GraphKeys.QUEUE_RUNNERS`.
  """
  ops.add_to_collection(collection, qr)


@tf_export("train.queue_runner.start_queue_runners",
           "train.start_queue_runners")
def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
                        collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Starts all queue runners collected in the graph.

  This is a companion method to `add_queue_runner()`.  It just starts
  threads for all queue runners collected in the graph.  It returns
  the list of all threads.

  Args:
    sess: `Session` used to run the queue ops.  Defaults to the
      default session.
    coord: Optional `Coordinator` for coordinating the started threads.
    daemon: Whether the threads should be marked as `daemons`, meaning
      they don't block program exit.
    start: Set to `False` to only create the threads, not start them.
    collection: A `GraphKey` specifying the graph collection to
      get the queue runners from.  Defaults to `GraphKeys.QUEUE_RUNNERS`.

  Returns:
    A list of threads.  An empty list is returned when `sess` is a
    `MonitoredSession` or `SingularMonitoredSession`.

  Raises:
    RuntimeError: If called with eager execution enabled.
    ValueError: If `sess` is None and there isn't any default session.
    TypeError: If `sess` is not a `tf.Session` object.

  @compatibility(eager)
  Not compatible with eager execution. To ingest data under eager execution,
  use the `tf.data` API instead.
  @end_compatibility
  """
  if context.in_eager_mode():
    raise RuntimeError("Queues are not compatible with eager execution.")
  if sess is None:
    sess = ops.get_default_session()
    if not sess:
      raise ValueError("Cannot start queue runners: No default session is "
                       "registered. Use `with sess.as_default()` or pass an "
                       "explicit session to tf.start_queue_runners(sess=sess)")

  if not isinstance(sess, session.SessionInterface):
    # Following check is due to backward compatibility. (b/62061352)
    if sess.__class__.__name__ in [
        "MonitoredSession", "SingularMonitoredSession"]:
      return []
    raise TypeError("sess must be a `tf.Session` object. "
                    "Given class: {}".format(sess.__class__))

  with sess.graph.as_default():
    threads = []
    for qr in ops.get_collection(collection):
      threads.extend(qr.create_threads(sess, coord=coord, daemon=daemon,
                                       start=start))
  return threads


ops.register_proto_function(ops.GraphKeys.QUEUE_RUNNERS,
                            proto_type=queue_runner_pb2.QueueRunnerDef,
                            to_proto=QueueRunner.to_proto,
                            from_proto=QueueRunner.from_proto)