Home | History | Annotate | Download | only in training
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for QueueRunner."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import collections
     22 import time
     23 
     24 from tensorflow.python.client import session
     25 from tensorflow.python.framework import constant_op
     26 from tensorflow.python.framework import dtypes
     27 from tensorflow.python.framework import errors_impl
     28 from tensorflow.python.framework import ops
     29 from tensorflow.python.ops import control_flow_ops
     30 from tensorflow.python.ops import data_flow_ops
     31 from tensorflow.python.ops import variables
     32 from tensorflow.python.platform import test
     33 from tensorflow.python.training import coordinator
     34 from tensorflow.python.training import monitored_session
     35 from tensorflow.python.training import queue_runner_impl
     36 
     37 
     38 _MockOp = collections.namedtuple("MockOp", ["name"])
     39 
     40 
     41 class QueueRunnerTest(test.TestCase):
     42 
     43   def testBasic(self):
     44     with self.test_session() as sess:
     45       # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
     46       zero64 = constant_op.constant(0, dtype=dtypes.int64)
     47       var = variables.Variable(zero64)
     48       count_up_to = var.count_up_to(3)
     49       queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
     50       variables.global_variables_initializer().run()
     51       qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
     52       threads = qr.create_threads(sess)
     53       self.assertEqual(sorted(t.name for t in threads),
     54                        ["QueueRunnerThread-fifo_queue-CountUpTo:0"])
     55       for t in threads:
     56         t.start()
     57       for t in threads:
     58         t.join()
     59       self.assertEqual(0, len(qr.exceptions_raised))
     60       # The variable should be 3.
     61       self.assertEqual(3, var.eval())
     62 
     63   def testTwoOps(self):
     64     with self.test_session() as sess:
     65       # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
     66       zero64 = constant_op.constant(0, dtype=dtypes.int64)
     67       var0 = variables.Variable(zero64)
     68       count_up_to_3 = var0.count_up_to(3)
     69       var1 = variables.Variable(zero64)
     70       count_up_to_30 = var1.count_up_to(30)
     71       queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
     72       qr = queue_runner_impl.QueueRunner(queue, [count_up_to_3, count_up_to_30])
     73       threads = qr.create_threads(sess)
     74       self.assertEqual(sorted(t.name for t in threads),
     75                        ["QueueRunnerThread-fifo_queue-CountUpTo:0",
     76                         "QueueRunnerThread-fifo_queue-CountUpTo_1:0"])
     77       variables.global_variables_initializer().run()
     78       for t in threads:
     79         t.start()
     80       for t in threads:
     81         t.join()
     82       self.assertEqual(0, len(qr.exceptions_raised))
     83       self.assertEqual(3, var0.eval())
     84       self.assertEqual(30, var1.eval())
     85 
     86   def testExceptionsCaptured(self):
     87     with self.test_session() as sess:
     88       queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
     89       qr = queue_runner_impl.QueueRunner(queue, [_MockOp("i fail"),
     90                                                  _MockOp("so fail")])
     91       threads = qr.create_threads(sess)
     92       variables.global_variables_initializer().run()
     93       for t in threads:
     94         t.start()
     95       for t in threads:
     96         t.join()
     97       exceptions = qr.exceptions_raised
     98       self.assertEqual(2, len(exceptions))
     99       self.assertTrue("Operation not in the graph" in str(exceptions[0]))
    100       self.assertTrue("Operation not in the graph" in str(exceptions[1]))
    101 
    102   def testRealDequeueEnqueue(self):
    103     with self.test_session() as sess:
    104       q0 = data_flow_ops.FIFOQueue(3, dtypes.float32)
    105       enqueue0 = q0.enqueue((10.0,))
    106       close0 = q0.close()
    107       q1 = data_flow_ops.FIFOQueue(30, dtypes.float32)
    108       enqueue1 = q1.enqueue((q0.dequeue(),))
    109       dequeue1 = q1.dequeue()
    110       qr = queue_runner_impl.QueueRunner(q1, [enqueue1])
    111       threads = qr.create_threads(sess)
    112       for t in threads:
    113         t.start()
    114       # Enqueue 2 values, then close queue0.
    115       enqueue0.run()
    116       enqueue0.run()
    117       close0.run()
    118       # Wait for the queue runner to terminate.
    119       for t in threads:
    120         t.join()
    121       # It should have terminated cleanly.
    122       self.assertEqual(0, len(qr.exceptions_raised))
    123       # The 2 values should be in queue1.
    124       self.assertEqual(10.0, dequeue1.eval())
    125       self.assertEqual(10.0, dequeue1.eval())
    126       # And queue1 should now be closed.
    127       with self.assertRaisesRegexp(errors_impl.OutOfRangeError, "is closed"):
    128         dequeue1.eval()
    129 
    130   def testRespectCoordShouldStop(self):
    131     with self.test_session() as sess:
    132       # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
    133       zero64 = constant_op.constant(0, dtype=dtypes.int64)
    134       var = variables.Variable(zero64)
    135       count_up_to = var.count_up_to(3)
    136       queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
    137       variables.global_variables_initializer().run()
    138       qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
    139       # As the coordinator to stop.  The queue runner should
    140       # finish immediately.
    141       coord = coordinator.Coordinator()
    142       coord.request_stop()
    143       threads = qr.create_threads(sess, coord)
    144       self.assertEqual(sorted(t.name for t in threads),
    145                        ["QueueRunnerThread-fifo_queue-CountUpTo:0",
    146                         "QueueRunnerThread-fifo_queue-close_on_stop"])
    147       for t in threads:
    148         t.start()
    149       coord.join()
    150       self.assertEqual(0, len(qr.exceptions_raised))
    151       # The variable should be 0.
    152       self.assertEqual(0, var.eval())
    153 
    154   def testRequestStopOnException(self):
    155     with self.test_session() as sess:
    156       queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
    157       qr = queue_runner_impl.QueueRunner(queue, [_MockOp("not an op")])
    158       coord = coordinator.Coordinator()
    159       threads = qr.create_threads(sess, coord)
    160       for t in threads:
    161         t.start()
    162       # The exception should be re-raised when joining.
    163       with self.assertRaisesRegexp(ValueError, "Operation not in the graph"):
    164         coord.join()
    165 
    166   def testGracePeriod(self):
    167     with self.test_session() as sess:
    168       # The enqueue will quickly block.
    169       queue = data_flow_ops.FIFOQueue(2, dtypes.float32)
    170       enqueue = queue.enqueue((10.0,))
    171       dequeue = queue.dequeue()
    172       qr = queue_runner_impl.QueueRunner(queue, [enqueue])
    173       coord = coordinator.Coordinator()
    174       qr.create_threads(sess, coord, start=True)
    175       # Dequeue one element and then request stop.
    176       dequeue.op.run()
    177       time.sleep(0.02)
    178       coord.request_stop()
    179       # We should be able to join because the RequestStop() will cause
    180       # the queue to be closed and the enqueue to terminate.
    181       coord.join(stop_grace_period_secs=1.0)
    182 
    183   def testMultipleSessions(self):
    184     with self.test_session() as sess:
    185       with session.Session() as other_sess:
    186         zero64 = constant_op.constant(0, dtype=dtypes.int64)
    187         var = variables.Variable(zero64)
    188         count_up_to = var.count_up_to(3)
    189         queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
    190         variables.global_variables_initializer().run()
    191         coord = coordinator.Coordinator()
    192         qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
    193         # NOTE that this test does not actually start the threads.
    194         threads = qr.create_threads(sess, coord=coord)
    195         other_threads = qr.create_threads(other_sess, coord=coord)
    196         self.assertEqual(len(threads), len(other_threads))
    197 
    198   def testIgnoreMultiStarts(self):
    199     with self.test_session() as sess:
    200       # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
    201       zero64 = constant_op.constant(0, dtype=dtypes.int64)
    202       var = variables.Variable(zero64)
    203       count_up_to = var.count_up_to(3)
    204       queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
    205       variables.global_variables_initializer().run()
    206       coord = coordinator.Coordinator()
    207       qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
    208       threads = []
    209       # NOTE that this test does not actually start the threads.
    210       threads.extend(qr.create_threads(sess, coord=coord))
    211       new_threads = qr.create_threads(sess, coord=coord)
    212       self.assertEqual([], new_threads)
    213 
    214   def testThreads(self):
    215     with self.test_session() as sess:
    216       # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
    217       zero64 = constant_op.constant(0, dtype=dtypes.int64)
    218       var = variables.Variable(zero64)
    219       count_up_to = var.count_up_to(3)
    220       queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
    221       variables.global_variables_initializer().run()
    222       qr = queue_runner_impl.QueueRunner(queue, [count_up_to,
    223                                                  _MockOp("bad_op")])
    224       threads = qr.create_threads(sess, start=True)
    225       self.assertEqual(sorted(t.name for t in threads),
    226                        ["QueueRunnerThread-fifo_queue-CountUpTo:0",
    227                         "QueueRunnerThread-fifo_queue-bad_op"])
    228       for t in threads:
    229         t.join()
    230       exceptions = qr.exceptions_raised
    231       self.assertEqual(1, len(exceptions))
    232       self.assertTrue("Operation not in the graph" in str(exceptions[0]))
    233 
    234       threads = qr.create_threads(sess, start=True)
    235       for t in threads:
    236         t.join()
    237       exceptions = qr.exceptions_raised
    238       self.assertEqual(1, len(exceptions))
    239       self.assertTrue("Operation not in the graph" in str(exceptions[0]))
    240 
    241   def testName(self):
    242     with ops.name_scope("scope"):
    243       queue = data_flow_ops.FIFOQueue(10, dtypes.float32, name="queue")
    244     qr = queue_runner_impl.QueueRunner(queue, [control_flow_ops.no_op()])
    245     self.assertEqual("scope/queue", qr.name)
    246     queue_runner_impl.add_queue_runner(qr)
    247     self.assertEqual(
    248         1, len(ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS, "scope")))
    249 
    250   def testStartQueueRunners(self):
    251     # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
    252     zero64 = constant_op.constant(0, dtype=dtypes.int64)
    253     var = variables.Variable(zero64)
    254     count_up_to = var.count_up_to(3)
    255     queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
    256     init_op = variables.global_variables_initializer()
    257     qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
    258     queue_runner_impl.add_queue_runner(qr)
    259     with self.test_session() as sess:
    260       init_op.run()
    261       threads = queue_runner_impl.start_queue_runners(sess)
    262       for t in threads:
    263         t.join()
    264       self.assertEqual(0, len(qr.exceptions_raised))
    265       # The variable should be 3.
    266       self.assertEqual(3, var.eval())
    267 
    268   def testStartQueueRunnersRaisesIfNotASession(self):
    269     zero64 = constant_op.constant(0, dtype=dtypes.int64)
    270     var = variables.Variable(zero64)
    271     count_up_to = var.count_up_to(3)
    272     queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
    273     init_op = variables.global_variables_initializer()
    274     qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
    275     queue_runner_impl.add_queue_runner(qr)
    276     with self.test_session():
    277       init_op.run()
    278       with self.assertRaisesRegexp(TypeError, "tf.Session"):
    279         queue_runner_impl.start_queue_runners("NotASession")
    280 
    281   def testStartQueueRunnersIgnoresMonitoredSession(self):
    282     zero64 = constant_op.constant(0, dtype=dtypes.int64)
    283     var = variables.Variable(zero64)
    284     count_up_to = var.count_up_to(3)
    285     queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
    286     init_op = variables.global_variables_initializer()
    287     qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
    288     queue_runner_impl.add_queue_runner(qr)
    289     with self.test_session():
    290       init_op.run()
    291       threads = queue_runner_impl.start_queue_runners(
    292           monitored_session.MonitoredSession())
    293       self.assertFalse(threads)
    294 
    295   def testStartQueueRunnersNonDefaultGraph(self):
    296     # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
    297     graph = ops.Graph()
    298     with graph.as_default():
    299       zero64 = constant_op.constant(0, dtype=dtypes.int64)
    300       var = variables.Variable(zero64)
    301       count_up_to = var.count_up_to(3)
    302       queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
    303       init_op = variables.global_variables_initializer()
    304       qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
    305       queue_runner_impl.add_queue_runner(qr)
    306     with self.test_session(graph=graph) as sess:
    307       init_op.run()
    308       threads = queue_runner_impl.start_queue_runners(sess)
    309       for t in threads:
    310         t.join()
    311       self.assertEqual(0, len(qr.exceptions_raised))
    312       # The variable should be 3.
    313       self.assertEqual(3, var.eval())
    314 
    315   def testQueueRunnerSerializationRoundTrip(self):
    316     graph = ops.Graph()
    317     with graph.as_default():
    318       queue = data_flow_ops.FIFOQueue(10, dtypes.float32, name="queue")
    319       enqueue_op = control_flow_ops.no_op(name="enqueue")
    320       close_op = control_flow_ops.no_op(name="close")
    321       cancel_op = control_flow_ops.no_op(name="cancel")
    322       qr0 = queue_runner_impl.QueueRunner(
    323           queue, [enqueue_op],
    324           close_op,
    325           cancel_op,
    326           queue_closed_exception_types=(errors_impl.OutOfRangeError,
    327                                         errors_impl.CancelledError))
    328       qr0_proto = queue_runner_impl.QueueRunner.to_proto(qr0)
    329       qr0_recon = queue_runner_impl.QueueRunner.from_proto(qr0_proto)
    330       self.assertEqual("queue", qr0_recon.queue.name)
    331       self.assertEqual(1, len(qr0_recon.enqueue_ops))
    332       self.assertEqual(enqueue_op, qr0_recon.enqueue_ops[0])
    333       self.assertEqual(close_op, qr0_recon.close_op)
    334       self.assertEqual(cancel_op, qr0_recon.cancel_op)
    335       self.assertEqual(
    336           (errors_impl.OutOfRangeError, errors_impl.CancelledError),
    337           qr0_recon.queue_closed_exception_types)
    338 
    339       # Assert we reconstruct an OutOfRangeError for QueueRunners
    340       # created before QueueRunnerDef had a queue_closed_exception_types field.
    341       del qr0_proto.queue_closed_exception_types[:]
    342       qr0_legacy_recon = queue_runner_impl.QueueRunner.from_proto(qr0_proto)
    343       self.assertEqual("queue", qr0_legacy_recon.queue.name)
    344       self.assertEqual(1, len(qr0_legacy_recon.enqueue_ops))
    345       self.assertEqual(enqueue_op, qr0_legacy_recon.enqueue_ops[0])
    346       self.assertEqual(close_op, qr0_legacy_recon.close_op)
    347       self.assertEqual(cancel_op, qr0_legacy_recon.cancel_op)
    348       self.assertEqual((errors_impl.OutOfRangeError,),
    349                        qr0_legacy_recon.queue_closed_exception_types)
    350 
    351 
if __name__ == "__main__":
  # Run the test cases in this module through TensorFlow's test runner.
  test.main()
    354