Home | History | Annotate | Download | only in tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tensorflow.ops.data_flow_ops.FIFOQueue."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import time
     22 
     23 from six.moves import xrange  # pylint: disable=redefined-builtin
     24 
     25 from tensorflow.compiler.tests import xla_test
     26 from tensorflow.python.framework import dtypes as dtypes_lib
     27 from tensorflow.python.ops import data_flow_ops
     28 from tensorflow.python.platform import test
     29 
     30 
     31 class FIFOQueueTest(xla_test.XLATestCase):
     32 
     33   def testEnqueue(self):
     34     with self.cached_session(), self.test_scope():
     35       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
     36       enqueue_op = q.enqueue((10.0,))
     37       enqueue_op.run()
     38 
     39   def testEnqueueWithShape(self):
     40     with self.cached_session(), self.test_scope():
     41       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2))
     42       enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
     43       enqueue_correct_op.run()
     44       with self.assertRaises(ValueError):
     45         q.enqueue(([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],))
     46       self.assertEqual(1, q.size().eval())
     47 
     48   def testMultipleDequeues(self):
     49     with self.cached_session(), self.test_scope():
     50       q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
     51       self.evaluate(q.enqueue([1]))
     52       self.evaluate(q.enqueue([2]))
     53       self.evaluate(q.enqueue([3]))
     54       a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()])
     55       self.assertAllEqual(set([1, 2, 3]), set([a, b, c]))
     56 
     57   def testQueuesDontShare(self):
     58     with self.cached_session(), self.test_scope():
     59       q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
     60       self.evaluate(q.enqueue(1))
     61       q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
     62       self.evaluate(q2.enqueue(2))
     63       self.assertAllEqual(self.evaluate(q2.dequeue()), 2)
     64       self.assertAllEqual(self.evaluate(q.dequeue()), 1)
     65 
     66   def testEnqueueDictWithoutNames(self):
     67     with self.cached_session(), self.test_scope():
     68       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
     69       with self.assertRaisesRegexp(ValueError, "must have names"):
     70         q.enqueue({"a": 12.0})
     71 
     72   def testParallelEnqueue(self):
     73     with self.cached_session() as sess, self.test_scope():
     74       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
     75       elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
     76       enqueue_ops = [q.enqueue((x,)) for x in elems]
     77       dequeued_t = q.dequeue()
     78 
     79       # Run one producer thread for each element in elems.
     80       def enqueue(enqueue_op):
     81         sess.run(enqueue_op)
     82 
     83       threads = [
     84           self.checkedThread(target=enqueue, args=(e,)) for e in enqueue_ops
     85       ]
     86       for thread in threads:
     87         thread.start()
     88       for thread in threads:
     89         thread.join()
     90 
     91       # Dequeue every element using a single thread.
     92       results = []
     93       for _ in xrange(len(elems)):
     94         results.append(dequeued_t.eval())
     95       self.assertItemsEqual(elems, results)
     96 
     97   def testParallelDequeue(self):
     98     with self.cached_session() as sess, self.test_scope():
     99       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
    100       elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
    101       enqueue_ops = [q.enqueue((x,)) for x in elems]
    102       dequeued_t = q.dequeue()
    103 
    104       # Enqueue every element using a single thread.
    105       for enqueue_op in enqueue_ops:
    106         enqueue_op.run()
    107 
    108       # Run one consumer thread for each element in elems.
    109       results = []
    110 
    111       def dequeue():
    112         results.append(sess.run(dequeued_t))
    113 
    114       threads = [self.checkedThread(target=dequeue) for _ in enqueue_ops]
    115       for thread in threads:
    116         thread.start()
    117       for thread in threads:
    118         thread.join()
    119       self.assertItemsEqual(elems, results)
    120 
    121   def testDequeue(self):
    122     with self.cached_session(), self.test_scope():
    123       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
    124       elems = [10.0, 20.0, 30.0]
    125       enqueue_ops = [q.enqueue((x,)) for x in elems]
    126       dequeued_t = q.dequeue()
    127 
    128       for enqueue_op in enqueue_ops:
    129         enqueue_op.run()
    130 
    131       for i in xrange(len(elems)):
    132         vals = self.evaluate(dequeued_t)
    133         self.assertEqual([elems[i]], vals)
    134 
    135   def testEnqueueAndBlockingDequeue(self):
    136     with self.cached_session() as sess, self.test_scope():
    137       q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32)
    138       elems = [10.0, 20.0, 30.0]
    139       enqueue_ops = [q.enqueue((x,)) for x in elems]
    140       dequeued_t = q.dequeue()
    141 
    142       def enqueue():
    143         # The enqueue_ops should run after the dequeue op has blocked.
    144         # TODO(mrry): Figure out how to do this without sleeping.
    145         time.sleep(0.1)
    146         for enqueue_op in enqueue_ops:
    147           sess.run(enqueue_op)
    148 
    149       results = []
    150 
    151       def dequeue():
    152         for _ in xrange(len(elems)):
    153           results.append(sess.run(dequeued_t))
    154 
    155       enqueue_thread = self.checkedThread(target=enqueue)
    156       dequeue_thread = self.checkedThread(target=dequeue)
    157       enqueue_thread.start()
    158       dequeue_thread.start()
    159       enqueue_thread.join()
    160       dequeue_thread.join()
    161 
    162       for elem, result in zip(elems, results):
    163         self.assertEqual([elem], result)
    164 
    165   def testMultiEnqueueAndDequeue(self):
    166     with self.cached_session() as sess, self.test_scope():
    167       q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32))
    168       elems = [(5, 10.0), (10, 20.0), (15, 30.0)]
    169       enqueue_ops = [q.enqueue((x, y)) for x, y in elems]
    170       dequeued_t = q.dequeue()
    171 
    172       for enqueue_op in enqueue_ops:
    173         enqueue_op.run()
    174 
    175       for i in xrange(len(elems)):
    176         x_val, y_val = sess.run(dequeued_t)
    177         x, y = elems[i]
    178         self.assertEqual([x], x_val)
    179         self.assertEqual([y], y_val)
    180 
    181   def testQueueSizeEmpty(self):
    182     with self.cached_session(), self.test_scope():
    183       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
    184       self.assertEqual([0], q.size().eval())
    185 
    186   def testQueueSizeAfterEnqueueAndDequeue(self):
    187     with self.cached_session(), self.test_scope():
    188       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
    189       enqueue_op = q.enqueue((10.0,))
    190       dequeued_t = q.dequeue()
    191       size = q.size()
    192       self.assertEqual([], size.get_shape())
    193 
    194       enqueue_op.run()
    195       self.assertEqual(1, self.evaluate(size))
    196       dequeued_t.op.run()
    197       self.assertEqual(0, self.evaluate(size))
    198 
    199 
if __name__ == "__main__":
  # Run this module's tests through TensorFlow's test entry point
  # (`tensorflow.python.platform.test.main`).
  test.main()
    202