1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for tensorflow.ops.data_flow_ops.FIFOQueue.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import time 22 23 from six.moves import xrange # pylint: disable=redefined-builtin 24 25 from tensorflow.compiler.tests import xla_test 26 from tensorflow.python.framework import dtypes as dtypes_lib 27 from tensorflow.python.ops import data_flow_ops 28 from tensorflow.python.platform import test 29 30 31 class FIFOQueueTest(xla_test.XLATestCase): 32 33 def testEnqueue(self): 34 with self.cached_session(), self.test_scope(): 35 q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) 36 enqueue_op = q.enqueue((10.0,)) 37 enqueue_op.run() 38 39 def testEnqueueWithShape(self): 40 with self.cached_session(), self.test_scope(): 41 q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2)) 42 enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],)) 43 enqueue_correct_op.run() 44 with self.assertRaises(ValueError): 45 q.enqueue(([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],)) 46 self.assertEqual(1, q.size().eval()) 47 48 def testMultipleDequeues(self): 49 with self.cached_session(), self.test_scope(): 50 q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) 51 self.evaluate(q.enqueue([1])) 52 self.evaluate(q.enqueue([2])) 53 self.evaluate(q.enqueue([3])) 54 a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()]) 55 self.assertAllEqual(set([1, 2, 3]), set([a, b, c])) 56 57 def testQueuesDontShare(self): 58 with self.cached_session(), self.test_scope(): 59 q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) 60 self.evaluate(q.enqueue(1)) 61 q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) 62 self.evaluate(q2.enqueue(2)) 63 self.assertAllEqual(self.evaluate(q2.dequeue()), 2) 64 self.assertAllEqual(self.evaluate(q.dequeue()), 1) 65 66 def testEnqueueDictWithoutNames(self): 67 with self.cached_session(), self.test_scope(): 68 q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) 69 with self.assertRaisesRegexp(ValueError, "must have names"): 70 q.enqueue({"a": 12.0}) 71 72 def testParallelEnqueue(self): 73 with self.cached_session() as sess, self.test_scope(): 74 q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) 75 elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] 76 enqueue_ops = [q.enqueue((x,)) for x in elems] 77 dequeued_t = q.dequeue() 78 79 # Run one producer thread for each element in elems. 80 def enqueue(enqueue_op): 81 sess.run(enqueue_op) 82 83 threads = [ 84 self.checkedThread(target=enqueue, args=(e,)) for e in enqueue_ops 85 ] 86 for thread in threads: 87 thread.start() 88 for thread in threads: 89 thread.join() 90 91 # Dequeue every element using a single thread. 92 results = [] 93 for _ in xrange(len(elems)): 94 results.append(dequeued_t.eval()) 95 self.assertItemsEqual(elems, results) 96 97 def testParallelDequeue(self): 98 with self.cached_session() as sess, self.test_scope(): 99 q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) 100 elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] 101 enqueue_ops = [q.enqueue((x,)) for x in elems] 102 dequeued_t = q.dequeue() 103 104 # Enqueue every element using a single thread. 105 for enqueue_op in enqueue_ops: 106 enqueue_op.run() 107 108 # Run one consumer thread for each element in elems. 109 results = [] 110 111 def dequeue(): 112 results.append(sess.run(dequeued_t)) 113 114 threads = [self.checkedThread(target=dequeue) for _ in enqueue_ops] 115 for thread in threads: 116 thread.start() 117 for thread in threads: 118 thread.join() 119 self.assertItemsEqual(elems, results) 120 121 def testDequeue(self): 122 with self.cached_session(), self.test_scope(): 123 q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) 124 elems = [10.0, 20.0, 30.0] 125 enqueue_ops = [q.enqueue((x,)) for x in elems] 126 dequeued_t = q.dequeue() 127 128 for enqueue_op in enqueue_ops: 129 enqueue_op.run() 130 131 for i in xrange(len(elems)): 132 vals = self.evaluate(dequeued_t) 133 self.assertEqual([elems[i]], vals) 134 135 def testEnqueueAndBlockingDequeue(self): 136 with self.cached_session() as sess, self.test_scope(): 137 q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32) 138 elems = [10.0, 20.0, 30.0] 139 enqueue_ops = [q.enqueue((x,)) for x in elems] 140 dequeued_t = q.dequeue() 141 142 def enqueue(): 143 # The enqueue_ops should run after the dequeue op has blocked. 144 # TODO(mrry): Figure out how to do this without sleeping. 145 time.sleep(0.1) 146 for enqueue_op in enqueue_ops: 147 sess.run(enqueue_op) 148 149 results = [] 150 151 def dequeue(): 152 for _ in xrange(len(elems)): 153 results.append(sess.run(dequeued_t)) 154 155 enqueue_thread = self.checkedThread(target=enqueue) 156 dequeue_thread = self.checkedThread(target=dequeue) 157 enqueue_thread.start() 158 dequeue_thread.start() 159 enqueue_thread.join() 160 dequeue_thread.join() 161 162 for elem, result in zip(elems, results): 163 self.assertEqual([elem], result) 164 165 def testMultiEnqueueAndDequeue(self): 166 with self.cached_session() as sess, self.test_scope(): 167 q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32)) 168 elems = [(5, 10.0), (10, 20.0), (15, 30.0)] 169 enqueue_ops = [q.enqueue((x, y)) for x, y in elems] 170 dequeued_t = q.dequeue() 171 172 for enqueue_op in enqueue_ops: 173 enqueue_op.run() 174 175 for i in xrange(len(elems)): 176 x_val, y_val = sess.run(dequeued_t) 177 x, y = elems[i] 178 self.assertEqual([x], x_val) 179 self.assertEqual([y], y_val) 180 181 def testQueueSizeEmpty(self): 182 with self.cached_session(), self.test_scope(): 183 q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) 184 self.assertEqual([0], q.size().eval()) 185 186 def testQueueSizeAfterEnqueueAndDequeue(self): 187 with self.cached_session(), self.test_scope(): 188 q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) 189 enqueue_op = q.enqueue((10.0,)) 190 dequeued_t = q.dequeue() 191 size = q.size() 192 self.assertEqual([], size.get_shape()) 193 194 enqueue_op.run() 195 self.assertEqual(1, self.evaluate(size)) 196 dequeued_t.op.run() 197 self.assertEqual(0, self.evaluate(size)) 198 199 200 if __name__ == "__main__": 201 test.main() 202