1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for slim.data.prefetch_queue.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.contrib.slim.python.slim.data import prefetch_queue 24 from tensorflow.python.framework import constant_op 25 from tensorflow.python.framework import dtypes 26 from tensorflow.python.framework import errors_impl 27 from tensorflow.python.framework import ops 28 from tensorflow.python.ops import array_ops 29 from tensorflow.python.ops import data_flow_ops 30 from tensorflow.python.ops import random_ops 31 from tensorflow.python.ops import variables 32 from tensorflow.python.platform import test 33 from tensorflow.python.training import input as input_lib 34 from tensorflow.python.training import queue_runner_impl 35 36 37 class PrefetchQueueTest(test.TestCase): 38 39 def testOneThread(self): 40 with self.test_session() as sess: 41 batch_size = 10 42 image_size = 32 43 num_batches = 5 44 45 zero64 = constant_op.constant(0, dtype=dtypes.int64) 46 47 examples = variables.Variable(zero64) 48 counter = examples.count_up_to(num_batches * batch_size) 49 image = random_ops.random_normal( 50 [image_size, image_size, 3], dtype=dtypes.float32, name='images') 51 label = random_ops.random_uniform( 52 [1], 0, 10, dtype=dtypes.int32, name='labels') 53 54 batches = input_lib.batch( 55 [counter, image, label], batch_size=batch_size, num_threads=1) 56 57 batches = prefetch_queue.prefetch_queue(batches).dequeue() 58 59 variables.global_variables_initializer().run() 60 threads = queue_runner_impl.start_queue_runners() 61 62 for i in range(num_batches): 63 results = sess.run(batches) 64 self.assertAllEqual(results[0], 65 np.arange(i * batch_size, (i + 1) * batch_size)) 66 self.assertEquals(results[1].shape, 67 (batch_size, image_size, image_size, 3)) 68 self.assertEquals(results[2].shape, (batch_size, 1)) 69 70 # Reached the limit. 71 with self.assertRaises(errors_impl.OutOfRangeError): 72 sess.run(batches) 73 for thread in threads: 74 thread.join() 75 76 def testMultiThread(self): 77 with self.test_session() as sess: 78 batch_size = 10 79 image_size = 32 80 num_batches = 5 81 82 zero64 = constant_op.constant(0, dtype=dtypes.int64) 83 84 examples = variables.Variable(zero64) 85 counter = examples.count_up_to(num_batches * batch_size) 86 image = random_ops.random_normal( 87 [image_size, image_size, 3], dtype=dtypes.float32, name='images') 88 label = random_ops.random_uniform( 89 [1], 0, 10, dtype=dtypes.int32, name='labels') 90 91 batches = input_lib.batch( 92 [counter, image, label], batch_size=batch_size, num_threads=4) 93 94 batches = prefetch_queue.prefetch_queue(batches).dequeue() 95 96 variables.global_variables_initializer().run() 97 threads = queue_runner_impl.start_queue_runners() 98 99 value_counter = [] 100 for _ in range(num_batches): 101 results = sess.run(batches) 102 value_counter.append(results[0]) 103 self.assertEqual(results[1].shape, 104 (batch_size, image_size, image_size, 3)) 105 self.assertEqual(results[2].shape, (batch_size, 1)) 106 107 self.assertAllEqual( 108 np.sort(np.concatenate(value_counter)), 109 np.arange(0, num_batches * batch_size)) 110 # Reached the limit. 111 with self.assertRaises(errors_impl.OutOfRangeError): 112 sess.run(batches) 113 for thread in threads: 114 thread.join() 115 116 def testMultipleDequeue(self): 117 with self.test_session() as sess: 118 batch_size = 10 119 image_size = 32 120 num_batches = 4 121 122 zero64 = constant_op.constant(0, dtype=dtypes.int64) 123 124 examples = variables.Variable(zero64) 125 counter = examples.count_up_to(num_batches * batch_size) 126 image = random_ops.random_normal( 127 [image_size, image_size, 3], dtype=dtypes.float32, name='images') 128 label = random_ops.random_uniform( 129 [1], 0, 10, dtype=dtypes.int32, name='labels') 130 131 batches = input_lib.batch( 132 [counter, image, label], batch_size=batch_size, num_threads=4) 133 134 batcher = prefetch_queue.prefetch_queue(batches) 135 batches_list = [batcher.dequeue() for _ in range(2)] 136 137 variables.global_variables_initializer().run() 138 threads = queue_runner_impl.start_queue_runners() 139 140 value_counter = [] 141 for _ in range(int(num_batches / 2)): 142 for batches in batches_list: 143 results = sess.run(batches) 144 value_counter.append(results[0]) 145 self.assertEquals(results[1].shape, 146 (batch_size, image_size, image_size, 3)) 147 self.assertEquals(results[2].shape, (batch_size, 1)) 148 149 self.assertAllEqual( 150 np.sort(np.concatenate(value_counter)), 151 np.arange(0, num_batches * batch_size)) 152 # Reached the limit. 153 with self.assertRaises(errors_impl.OutOfRangeError): 154 sess.run(batches) 155 for thread in threads: 156 thread.join() 157 158 def testDynamicPad_failure(self): 159 with ops.Graph().as_default(): 160 variable_tensor = array_ops.placeholder(dtypes.int32, shape=[None, 3]) 161 with self.assertRaisesRegexp(ValueError, 'shapes must be fully defined'): 162 prefetch_queue.prefetch_queue([variable_tensor]) 163 164 def testDynamicPad(self): 165 with self.test_session() as sess: 166 # Create 3 tensors of variable but compatible shapes. 167 var_shape = [None, 2] 168 p1 = constant_op.constant([[1, 2], [3, 4]]) 169 p1.set_shape(var_shape) 170 p2 = constant_op.constant([[5, 6], [7, 8], [9, 10]]) 171 p2.set_shape(var_shape) 172 p3 = constant_op.constant([[11, 12]]) 173 p3.set_shape(var_shape) 174 batch = [p1, p2, p3] 175 batch_size = len(batch) 176 177 zero64 = constant_op.constant(0, dtype=dtypes.int64) 178 examples = variables.Variable(zero64) 179 counter = examples.count_up_to(batch_size) 180 181 # Create a PaddingFIFOQueue to enqueue these tensors. 182 q = data_flow_ops.PaddingFIFOQueue( 183 capacity=10, dtypes=[dtypes.int32], shapes=[var_shape]) 184 for tensor in [p1, p2, p3]: 185 q.enqueue([tensor]).run() 186 187 # Dequeue from the queue and batch them using batch(). 188 batches = input_lib.batch([q.dequeue(), counter], batch_size=batch_size, 189 num_threads=1, dynamic_pad=True) 190 self.assertEqual([batch_size, None, 2], batches[0].shape.as_list()) 191 192 # Finally, assemble them into prefetch_queue with dynamic_pad. 193 batcher = prefetch_queue.prefetch_queue(batches, dynamic_pad=True) 194 batches = batcher.dequeue() 195 self.assertEqual([batch_size, None, 2], batches[0].shape.as_list()) 196 197 variables.global_variables_initializer().run() 198 threads = queue_runner_impl.start_queue_runners() 199 200 values, _ = sess.run(batches) 201 # We enqueued 3 tensors of [None, 2] shapes, so using dynamic_pad 202 # they should be padded to the fixed size [3, 3, 2], where 3 203 # is the maximum length of the batch. 204 self.assertTrue(np.array_equal( 205 np.array([[[1, 2], [3, 4], [0, 0]], 206 [[5, 6], [7, 8], [9, 10]], 207 [[11, 12], [0, 0], [0, 0]]]), 208 values)) 209 210 with self.assertRaises(errors_impl.OutOfRangeError): 211 sess.run(batches) 212 for thread in threads: 213 thread.join() 214 215 def testDictConstruction(self): 216 with ops.Graph().as_default(): 217 batches = { 218 'first': constant_op.constant([1]), 219 'second': constant_op.constant([2.0, 2.1]) 220 } 221 prefetcher = prefetch_queue.prefetch_queue(batches) 222 dequeued = prefetcher.dequeue() 223 self.assertTrue(isinstance(dequeued, dict)) 224 self.assertEqual(2, len(dequeued)) 225 self.assertEqual(dtypes.int32, dequeued['first'].dtype) 226 self.assertEqual(dtypes.float32, dequeued['second'].dtype) 227 228 229 if __name__ == '__main__': 230 test.main() 231