1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for QueueRunner.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import collections 22 import time 23 24 from tensorflow.python.client import session 25 from tensorflow.python.framework import constant_op 26 from tensorflow.python.framework import dtypes 27 from tensorflow.python.framework import errors_impl 28 from tensorflow.python.framework import ops 29 from tensorflow.python.ops import control_flow_ops 30 from tensorflow.python.ops import data_flow_ops 31 from tensorflow.python.ops import variables 32 from tensorflow.python.platform import test 33 from tensorflow.python.training import coordinator 34 from tensorflow.python.training import monitored_session 35 from tensorflow.python.training import queue_runner_impl 36 37 38 _MockOp = collections.namedtuple("MockOp", ["name"]) 39 40 41 class QueueRunnerTest(test.TestCase): 42 43 def testBasic(self): 44 with self.test_session() as sess: 45 # CountUpTo will raise OUT_OF_RANGE when it reaches the count. 46 zero64 = constant_op.constant(0, dtype=dtypes.int64) 47 var = variables.Variable(zero64) 48 count_up_to = var.count_up_to(3) 49 queue = data_flow_ops.FIFOQueue(10, dtypes.float32) 50 variables.global_variables_initializer().run() 51 qr = queue_runner_impl.QueueRunner(queue, [count_up_to]) 52 threads = qr.create_threads(sess) 53 self.assertEqual(sorted(t.name for t in threads), 54 ["QueueRunnerThread-fifo_queue-CountUpTo:0"]) 55 for t in threads: 56 t.start() 57 for t in threads: 58 t.join() 59 self.assertEqual(0, len(qr.exceptions_raised)) 60 # The variable should be 3. 61 self.assertEqual(3, var.eval()) 62 63 def testTwoOps(self): 64 with self.test_session() as sess: 65 # CountUpTo will raise OUT_OF_RANGE when it reaches the count. 66 zero64 = constant_op.constant(0, dtype=dtypes.int64) 67 var0 = variables.Variable(zero64) 68 count_up_to_3 = var0.count_up_to(3) 69 var1 = variables.Variable(zero64) 70 count_up_to_30 = var1.count_up_to(30) 71 queue = data_flow_ops.FIFOQueue(10, dtypes.float32) 72 qr = queue_runner_impl.QueueRunner(queue, [count_up_to_3, count_up_to_30]) 73 threads = qr.create_threads(sess) 74 self.assertEqual(sorted(t.name for t in threads), 75 ["QueueRunnerThread-fifo_queue-CountUpTo:0", 76 "QueueRunnerThread-fifo_queue-CountUpTo_1:0"]) 77 variables.global_variables_initializer().run() 78 for t in threads: 79 t.start() 80 for t in threads: 81 t.join() 82 self.assertEqual(0, len(qr.exceptions_raised)) 83 self.assertEqual(3, var0.eval()) 84 self.assertEqual(30, var1.eval()) 85 86 def testExceptionsCaptured(self): 87 with self.test_session() as sess: 88 queue = data_flow_ops.FIFOQueue(10, dtypes.float32) 89 qr = queue_runner_impl.QueueRunner(queue, [_MockOp("i fail"), 90 _MockOp("so fail")]) 91 threads = qr.create_threads(sess) 92 variables.global_variables_initializer().run() 93 for t in threads: 94 t.start() 95 for t in threads: 96 t.join() 97 exceptions = qr.exceptions_raised 98 self.assertEqual(2, len(exceptions)) 99 self.assertTrue("Operation not in the graph" in str(exceptions[0])) 100 self.assertTrue("Operation not in the graph" in str(exceptions[1])) 101 102 def testRealDequeueEnqueue(self): 103 with self.test_session() as sess: 104 q0 = data_flow_ops.FIFOQueue(3, dtypes.float32) 105 enqueue0 = q0.enqueue((10.0,)) 106 close0 = q0.close() 107 q1 = data_flow_ops.FIFOQueue(30, dtypes.float32) 108 enqueue1 = q1.enqueue((q0.dequeue(),)) 109 dequeue1 = q1.dequeue() 110 qr = queue_runner_impl.QueueRunner(q1, [enqueue1]) 111 threads = qr.create_threads(sess) 112 for t in threads: 113 t.start() 114 # Enqueue 2 values, then close queue0. 115 enqueue0.run() 116 enqueue0.run() 117 close0.run() 118 # Wait for the queue runner to terminate. 119 for t in threads: 120 t.join() 121 # It should have terminated cleanly. 122 self.assertEqual(0, len(qr.exceptions_raised)) 123 # The 2 values should be in queue1. 124 self.assertEqual(10.0, dequeue1.eval()) 125 self.assertEqual(10.0, dequeue1.eval()) 126 # And queue1 should now be closed. 127 with self.assertRaisesRegexp(errors_impl.OutOfRangeError, "is closed"): 128 dequeue1.eval() 129 130 def testRespectCoordShouldStop(self): 131 with self.test_session() as sess: 132 # CountUpTo will raise OUT_OF_RANGE when it reaches the count. 133 zero64 = constant_op.constant(0, dtype=dtypes.int64) 134 var = variables.Variable(zero64) 135 count_up_to = var.count_up_to(3) 136 queue = data_flow_ops.FIFOQueue(10, dtypes.float32) 137 variables.global_variables_initializer().run() 138 qr = queue_runner_impl.QueueRunner(queue, [count_up_to]) 139 # As the coordinator to stop. The queue runner should 140 # finish immediately. 141 coord = coordinator.Coordinator() 142 coord.request_stop() 143 threads = qr.create_threads(sess, coord) 144 self.assertEqual(sorted(t.name for t in threads), 145 ["QueueRunnerThread-fifo_queue-CountUpTo:0", 146 "QueueRunnerThread-fifo_queue-close_on_stop"]) 147 for t in threads: 148 t.start() 149 coord.join() 150 self.assertEqual(0, len(qr.exceptions_raised)) 151 # The variable should be 0. 152 self.assertEqual(0, var.eval()) 153 154 def testRequestStopOnException(self): 155 with self.test_session() as sess: 156 queue = data_flow_ops.FIFOQueue(10, dtypes.float32) 157 qr = queue_runner_impl.QueueRunner(queue, [_MockOp("not an op")]) 158 coord = coordinator.Coordinator() 159 threads = qr.create_threads(sess, coord) 160 for t in threads: 161 t.start() 162 # The exception should be re-raised when joining. 163 with self.assertRaisesRegexp(ValueError, "Operation not in the graph"): 164 coord.join() 165 166 def testGracePeriod(self): 167 with self.test_session() as sess: 168 # The enqueue will quickly block. 169 queue = data_flow_ops.FIFOQueue(2, dtypes.float32) 170 enqueue = queue.enqueue((10.0,)) 171 dequeue = queue.dequeue() 172 qr = queue_runner_impl.QueueRunner(queue, [enqueue]) 173 coord = coordinator.Coordinator() 174 qr.create_threads(sess, coord, start=True) 175 # Dequeue one element and then request stop. 176 dequeue.op.run() 177 time.sleep(0.02) 178 coord.request_stop() 179 # We should be able to join because the RequestStop() will cause 180 # the queue to be closed and the enqueue to terminate. 181 coord.join(stop_grace_period_secs=1.0) 182 183 def testMultipleSessions(self): 184 with self.test_session() as sess: 185 with session.Session() as other_sess: 186 zero64 = constant_op.constant(0, dtype=dtypes.int64) 187 var = variables.Variable(zero64) 188 count_up_to = var.count_up_to(3) 189 queue = data_flow_ops.FIFOQueue(10, dtypes.float32) 190 variables.global_variables_initializer().run() 191 coord = coordinator.Coordinator() 192 qr = queue_runner_impl.QueueRunner(queue, [count_up_to]) 193 # NOTE that this test does not actually start the threads. 194 threads = qr.create_threads(sess, coord=coord) 195 other_threads = qr.create_threads(other_sess, coord=coord) 196 self.assertEqual(len(threads), len(other_threads)) 197 198 def testIgnoreMultiStarts(self): 199 with self.test_session() as sess: 200 # CountUpTo will raise OUT_OF_RANGE when it reaches the count. 201 zero64 = constant_op.constant(0, dtype=dtypes.int64) 202 var = variables.Variable(zero64) 203 count_up_to = var.count_up_to(3) 204 queue = data_flow_ops.FIFOQueue(10, dtypes.float32) 205 variables.global_variables_initializer().run() 206 coord = coordinator.Coordinator() 207 qr = queue_runner_impl.QueueRunner(queue, [count_up_to]) 208 threads = [] 209 # NOTE that this test does not actually start the threads. 210 threads.extend(qr.create_threads(sess, coord=coord)) 211 new_threads = qr.create_threads(sess, coord=coord) 212 self.assertEqual([], new_threads) 213 214 def testThreads(self): 215 with self.test_session() as sess: 216 # CountUpTo will raise OUT_OF_RANGE when it reaches the count. 217 zero64 = constant_op.constant(0, dtype=dtypes.int64) 218 var = variables.Variable(zero64) 219 count_up_to = var.count_up_to(3) 220 queue = data_flow_ops.FIFOQueue(10, dtypes.float32) 221 variables.global_variables_initializer().run() 222 qr = queue_runner_impl.QueueRunner(queue, [count_up_to, 223 _MockOp("bad_op")]) 224 threads = qr.create_threads(sess, start=True) 225 self.assertEqual(sorted(t.name for t in threads), 226 ["QueueRunnerThread-fifo_queue-CountUpTo:0", 227 "QueueRunnerThread-fifo_queue-bad_op"]) 228 for t in threads: 229 t.join() 230 exceptions = qr.exceptions_raised 231 self.assertEqual(1, len(exceptions)) 232 self.assertTrue("Operation not in the graph" in str(exceptions[0])) 233 234 threads = qr.create_threads(sess, start=True) 235 for t in threads: 236 t.join() 237 exceptions = qr.exceptions_raised 238 self.assertEqual(1, len(exceptions)) 239 self.assertTrue("Operation not in the graph" in str(exceptions[0])) 240 241 def testName(self): 242 with ops.name_scope("scope"): 243 queue = data_flow_ops.FIFOQueue(10, dtypes.float32, name="queue") 244 qr = queue_runner_impl.QueueRunner(queue, [control_flow_ops.no_op()]) 245 self.assertEqual("scope/queue", qr.name) 246 queue_runner_impl.add_queue_runner(qr) 247 self.assertEqual( 248 1, len(ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS, "scope"))) 249 250 def testStartQueueRunners(self): 251 # CountUpTo will raise OUT_OF_RANGE when it reaches the count. 252 zero64 = constant_op.constant(0, dtype=dtypes.int64) 253 var = variables.Variable(zero64) 254 count_up_to = var.count_up_to(3) 255 queue = data_flow_ops.FIFOQueue(10, dtypes.float32) 256 init_op = variables.global_variables_initializer() 257 qr = queue_runner_impl.QueueRunner(queue, [count_up_to]) 258 queue_runner_impl.add_queue_runner(qr) 259 with self.test_session() as sess: 260 init_op.run() 261 threads = queue_runner_impl.start_queue_runners(sess) 262 for t in threads: 263 t.join() 264 self.assertEqual(0, len(qr.exceptions_raised)) 265 # The variable should be 3. 266 self.assertEqual(3, var.eval()) 267 268 def testStartQueueRunnersRaisesIfNotASession(self): 269 zero64 = constant_op.constant(0, dtype=dtypes.int64) 270 var = variables.Variable(zero64) 271 count_up_to = var.count_up_to(3) 272 queue = data_flow_ops.FIFOQueue(10, dtypes.float32) 273 init_op = variables.global_variables_initializer() 274 qr = queue_runner_impl.QueueRunner(queue, [count_up_to]) 275 queue_runner_impl.add_queue_runner(qr) 276 with self.test_session(): 277 init_op.run() 278 with self.assertRaisesRegexp(TypeError, "tf.Session"): 279 queue_runner_impl.start_queue_runners("NotASession") 280 281 def testStartQueueRunnersIgnoresMonitoredSession(self): 282 zero64 = constant_op.constant(0, dtype=dtypes.int64) 283 var = variables.Variable(zero64) 284 count_up_to = var.count_up_to(3) 285 queue = data_flow_ops.FIFOQueue(10, dtypes.float32) 286 init_op = variables.global_variables_initializer() 287 qr = queue_runner_impl.QueueRunner(queue, [count_up_to]) 288 queue_runner_impl.add_queue_runner(qr) 289 with self.test_session(): 290 init_op.run() 291 threads = queue_runner_impl.start_queue_runners( 292 monitored_session.MonitoredSession()) 293 self.assertFalse(threads) 294 295 def testStartQueueRunnersNonDefaultGraph(self): 296 # CountUpTo will raise OUT_OF_RANGE when it reaches the count. 297 graph = ops.Graph() 298 with graph.as_default(): 299 zero64 = constant_op.constant(0, dtype=dtypes.int64) 300 var = variables.Variable(zero64) 301 count_up_to = var.count_up_to(3) 302 queue = data_flow_ops.FIFOQueue(10, dtypes.float32) 303 init_op = variables.global_variables_initializer() 304 qr = queue_runner_impl.QueueRunner(queue, [count_up_to]) 305 queue_runner_impl.add_queue_runner(qr) 306 with self.test_session(graph=graph) as sess: 307 init_op.run() 308 threads = queue_runner_impl.start_queue_runners(sess) 309 for t in threads: 310 t.join() 311 self.assertEqual(0, len(qr.exceptions_raised)) 312 # The variable should be 3. 313 self.assertEqual(3, var.eval()) 314 315 def testQueueRunnerSerializationRoundTrip(self): 316 graph = ops.Graph() 317 with graph.as_default(): 318 queue = data_flow_ops.FIFOQueue(10, dtypes.float32, name="queue") 319 enqueue_op = control_flow_ops.no_op(name="enqueue") 320 close_op = control_flow_ops.no_op(name="close") 321 cancel_op = control_flow_ops.no_op(name="cancel") 322 qr0 = queue_runner_impl.QueueRunner( 323 queue, [enqueue_op], 324 close_op, 325 cancel_op, 326 queue_closed_exception_types=(errors_impl.OutOfRangeError, 327 errors_impl.CancelledError)) 328 qr0_proto = queue_runner_impl.QueueRunner.to_proto(qr0) 329 qr0_recon = queue_runner_impl.QueueRunner.from_proto(qr0_proto) 330 self.assertEqual("queue", qr0_recon.queue.name) 331 self.assertEqual(1, len(qr0_recon.enqueue_ops)) 332 self.assertEqual(enqueue_op, qr0_recon.enqueue_ops[0]) 333 self.assertEqual(close_op, qr0_recon.close_op) 334 self.assertEqual(cancel_op, qr0_recon.cancel_op) 335 self.assertEqual( 336 (errors_impl.OutOfRangeError, errors_impl.CancelledError), 337 qr0_recon.queue_closed_exception_types) 338 339 # Assert we reconstruct an OutOfRangeError for QueueRunners 340 # created before QueueRunnerDef had a queue_closed_exception_types field. 341 del qr0_proto.queue_closed_exception_types[:] 342 qr0_legacy_recon = queue_runner_impl.QueueRunner.from_proto(qr0_proto) 343 self.assertEqual("queue", qr0_legacy_recon.queue.name) 344 self.assertEqual(1, len(qr0_legacy_recon.enqueue_ops)) 345 self.assertEqual(enqueue_op, qr0_legacy_recon.enqueue_ops[0]) 346 self.assertEqual(close_op, qr0_legacy_recon.close_op) 347 self.assertEqual(cancel_op, qr0_legacy_recon.cancel_op) 348 self.assertEqual((errors_impl.OutOfRangeError,), 349 qr0_legacy_recon.queue_closed_exception_types) 350 351 352 if __name__ == "__main__": 353 test.main() 354