# Extracted from: tensorflow/contrib/slim/python/slim/data/prefetch_queue_test.py
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for slim.data.prefetch_queue."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.contrib.slim.python.slim.data import prefetch_queue
     24 from tensorflow.python.framework import constant_op
     25 from tensorflow.python.framework import dtypes
     26 from tensorflow.python.framework import errors_impl
     27 from tensorflow.python.framework import ops
     28 from tensorflow.python.ops import array_ops
     29 from tensorflow.python.ops import data_flow_ops
     30 from tensorflow.python.ops import random_ops
     31 from tensorflow.python.ops import variables
     32 from tensorflow.python.platform import test
     33 from tensorflow.python.training import input as input_lib
     34 from tensorflow.python.training import queue_runner_impl
     35 
     36 
     37 class PrefetchQueueTest(test.TestCase):
     38 
     39   def testOneThread(self):
     40     with self.test_session() as sess:
     41       batch_size = 10
     42       image_size = 32
     43       num_batches = 5
     44 
     45       zero64 = constant_op.constant(0, dtype=dtypes.int64)
     46 
     47       examples = variables.Variable(zero64)
     48       counter = examples.count_up_to(num_batches * batch_size)
     49       image = random_ops.random_normal(
     50           [image_size, image_size, 3], dtype=dtypes.float32, name='images')
     51       label = random_ops.random_uniform(
     52           [1], 0, 10, dtype=dtypes.int32, name='labels')
     53 
     54       batches = input_lib.batch(
     55           [counter, image, label], batch_size=batch_size, num_threads=1)
     56 
     57       batches = prefetch_queue.prefetch_queue(batches).dequeue()
     58 
     59       variables.global_variables_initializer().run()
     60       threads = queue_runner_impl.start_queue_runners()
     61 
     62       for i in range(num_batches):
     63         results = sess.run(batches)
     64         self.assertAllEqual(results[0],
     65                             np.arange(i * batch_size, (i + 1) * batch_size))
     66         self.assertEquals(results[1].shape,
     67                           (batch_size, image_size, image_size, 3))
     68         self.assertEquals(results[2].shape, (batch_size, 1))
     69 
     70       # Reached the limit.
     71       with self.assertRaises(errors_impl.OutOfRangeError):
     72         sess.run(batches)
     73       for thread in threads:
     74         thread.join()
     75 
     76   def testMultiThread(self):
     77     with self.test_session() as sess:
     78       batch_size = 10
     79       image_size = 32
     80       num_batches = 5
     81 
     82       zero64 = constant_op.constant(0, dtype=dtypes.int64)
     83 
     84       examples = variables.Variable(zero64)
     85       counter = examples.count_up_to(num_batches * batch_size)
     86       image = random_ops.random_normal(
     87           [image_size, image_size, 3], dtype=dtypes.float32, name='images')
     88       label = random_ops.random_uniform(
     89           [1], 0, 10, dtype=dtypes.int32, name='labels')
     90 
     91       batches = input_lib.batch(
     92           [counter, image, label], batch_size=batch_size, num_threads=4)
     93 
     94       batches = prefetch_queue.prefetch_queue(batches).dequeue()
     95 
     96       variables.global_variables_initializer().run()
     97       threads = queue_runner_impl.start_queue_runners()
     98 
     99       value_counter = []
    100       for _ in range(num_batches):
    101         results = sess.run(batches)
    102         value_counter.append(results[0])
    103         self.assertEqual(results[1].shape,
    104                          (batch_size, image_size, image_size, 3))
    105         self.assertEqual(results[2].shape, (batch_size, 1))
    106 
    107       self.assertAllEqual(
    108           np.sort(np.concatenate(value_counter)),
    109           np.arange(0, num_batches * batch_size))
    110       # Reached the limit.
    111       with self.assertRaises(errors_impl.OutOfRangeError):
    112         sess.run(batches)
    113       for thread in threads:
    114         thread.join()
    115 
    116   def testMultipleDequeue(self):
    117     with self.test_session() as sess:
    118       batch_size = 10
    119       image_size = 32
    120       num_batches = 4
    121 
    122       zero64 = constant_op.constant(0, dtype=dtypes.int64)
    123 
    124       examples = variables.Variable(zero64)
    125       counter = examples.count_up_to(num_batches * batch_size)
    126       image = random_ops.random_normal(
    127           [image_size, image_size, 3], dtype=dtypes.float32, name='images')
    128       label = random_ops.random_uniform(
    129           [1], 0, 10, dtype=dtypes.int32, name='labels')
    130 
    131       batches = input_lib.batch(
    132           [counter, image, label], batch_size=batch_size, num_threads=4)
    133 
    134       batcher = prefetch_queue.prefetch_queue(batches)
    135       batches_list = [batcher.dequeue() for _ in range(2)]
    136 
    137       variables.global_variables_initializer().run()
    138       threads = queue_runner_impl.start_queue_runners()
    139 
    140       value_counter = []
    141       for _ in range(int(num_batches / 2)):
    142         for batches in batches_list:
    143           results = sess.run(batches)
    144           value_counter.append(results[0])
    145           self.assertEquals(results[1].shape,
    146                             (batch_size, image_size, image_size, 3))
    147           self.assertEquals(results[2].shape, (batch_size, 1))
    148 
    149       self.assertAllEqual(
    150           np.sort(np.concatenate(value_counter)),
    151           np.arange(0, num_batches * batch_size))
    152       # Reached the limit.
    153       with self.assertRaises(errors_impl.OutOfRangeError):
    154         sess.run(batches)
    155       for thread in threads:
    156         thread.join()
    157 
    158   def testDynamicPad_failure(self):
    159     with ops.Graph().as_default():
    160       variable_tensor = array_ops.placeholder(dtypes.int32, shape=[None, 3])
    161       with self.assertRaisesRegexp(ValueError, 'shapes must be fully defined'):
    162         prefetch_queue.prefetch_queue([variable_tensor])
    163 
    164   def testDynamicPad(self):
    165     with self.test_session() as sess:
    166       # Create 3 tensors of variable but compatible shapes.
    167       var_shape = [None, 2]
    168       p1 = constant_op.constant([[1, 2], [3, 4]])
    169       p1.set_shape(var_shape)
    170       p2 = constant_op.constant([[5, 6], [7, 8], [9, 10]])
    171       p2.set_shape(var_shape)
    172       p3 = constant_op.constant([[11, 12]])
    173       p3.set_shape(var_shape)
    174       batch = [p1, p2, p3]
    175       batch_size = len(batch)
    176 
    177       zero64 = constant_op.constant(0, dtype=dtypes.int64)
    178       examples = variables.Variable(zero64)
    179       counter = examples.count_up_to(batch_size)
    180 
    181       # Create a PaddingFIFOQueue to enqueue these tensors.
    182       q = data_flow_ops.PaddingFIFOQueue(
    183           capacity=10, dtypes=[dtypes.int32], shapes=[var_shape])
    184       for tensor in [p1, p2, p3]:
    185         q.enqueue([tensor]).run()
    186 
    187       # Dequeue from the queue and batch them using batch().
    188       batches = input_lib.batch([q.dequeue(), counter], batch_size=batch_size,
    189                                 num_threads=1, dynamic_pad=True)
    190       self.assertEqual([batch_size, None, 2], batches[0].shape.as_list())
    191 
    192       # Finally, assemble them into prefetch_queue with dynamic_pad.
    193       batcher = prefetch_queue.prefetch_queue(batches, dynamic_pad=True)
    194       batches = batcher.dequeue()
    195       self.assertEqual([batch_size, None, 2], batches[0].shape.as_list())
    196 
    197       variables.global_variables_initializer().run()
    198       threads = queue_runner_impl.start_queue_runners()
    199 
    200       values, _ = sess.run(batches)
    201       # We enqueued 3 tensors of [None, 2] shapes, so using dynamic_pad
    202       # they should be padded to the fixed size [3, 3, 2], where 3
    203       # is the maximum length of the batch.
    204       self.assertTrue(np.array_equal(
    205           np.array([[[1, 2], [3, 4], [0, 0]],
    206                     [[5, 6], [7, 8], [9, 10]],
    207                     [[11, 12], [0, 0], [0, 0]]]),
    208           values))
    209 
    210       with self.assertRaises(errors_impl.OutOfRangeError):
    211         sess.run(batches)
    212       for thread in threads:
    213         thread.join()
    214 
    215   def testDictConstruction(self):
    216     with ops.Graph().as_default():
    217       batches = {
    218           'first': constant_op.constant([1]),
    219           'second': constant_op.constant([2.0, 2.1])
    220       }
    221       prefetcher = prefetch_queue.prefetch_queue(batches)
    222       dequeued = prefetcher.dequeue()
    223       self.assertTrue(isinstance(dequeued, dict))
    224       self.assertEqual(2, len(dequeued))
    225       self.assertEqual(dtypes.int32, dequeued['first'].dtype)
    226       self.assertEqual(dtypes.float32, dequeued['second'].dtype)
    227 
    228 
# Standard TensorFlow test entry point: discovers and runs all test methods
# in this module when the file is executed directly.
if __name__ == '__main__':
  test.main()
    231