Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for the experimental input pipeline ops."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import collections
     21 
     22 import numpy as np
     23 
     24 from tensorflow.python.data.ops import dataset_ops
     25 from tensorflow.python.data.ops import iterator_ops
     26 from tensorflow.python.framework import constant_op
     27 from tensorflow.python.framework import dtypes
     28 from tensorflow.python.framework import errors
     29 from tensorflow.python.ops import array_ops
     30 from tensorflow.python.platform import test
     31 
     32 
     33 class ShuffleDatasetTest(test.TestCase):
     34 
     35   def testShuffleDataset(self):
     36     components = (
     37         np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
     38         np.array([9.0, 10.0, 11.0, 12.0])
     39     )
     40     count_placeholder = array_ops.placeholder_with_default(
     41         constant_op.constant(5, dtypes.int64), shape=[])
     42     buffer_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
     43     seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
     44 
     45     repeat_dataset = (dataset_ops.Dataset.from_tensor_slices(components)
     46                       .repeat(count_placeholder))
     47 
     48     shuffle_dataset = repeat_dataset.shuffle(buffer_size_placeholder,
     49                                              seed_placeholder)
     50 
     51     self.assertEqual(tuple([c.shape[1:] for c in components]),
     52                      shuffle_dataset.output_shapes)
     53 
     54     # Create initialization ops for iterators without and with
     55     # shuffling, respectively.
     56     iterator = iterator_ops.Iterator.from_structure(
     57         shuffle_dataset.output_types, shuffle_dataset.output_shapes)
     58     init_fifo_op = iterator.make_initializer(repeat_dataset)
     59     init_shuffle_op = iterator.make_initializer(shuffle_dataset)
     60 
     61     get_next = iterator.get_next()
     62 
     63     with self.test_session() as sess:
     64       # First run without shuffling to collect the "ground truth".
     65       sess.run(init_fifo_op)
     66       unshuffled_elements = []
     67       for _ in range(20):
     68         unshuffled_elements.append(sess.run(get_next))
     69       with self.assertRaises(errors.OutOfRangeError):
     70         sess.run(get_next)
     71 
     72       # Assert that the shuffled dataset has the same elements as the
     73       # "ground truth".
     74       sess.run(
     75           init_shuffle_op,
     76           feed_dict={buffer_size_placeholder: 100,
     77                      seed_placeholder: 37})
     78       shuffled_elements = []
     79       for _ in range(20):
     80         shuffled_elements.append(sess.run(get_next))
     81       with self.assertRaises(errors.OutOfRangeError):
     82         sess.run(get_next)
     83       self.assertAllEqual(
     84           sorted(unshuffled_elements), sorted(shuffled_elements))
     85 
     86       # Assert that shuffling twice with the same seeds gives the same sequence.
     87       sess.run(
     88           init_shuffle_op,
     89           feed_dict={buffer_size_placeholder: 100,
     90                      seed_placeholder: 37})
     91       reshuffled_elements_same_seed = []
     92       for _ in range(20):
     93         reshuffled_elements_same_seed.append(sess.run(get_next))
     94       with self.assertRaises(errors.OutOfRangeError):
     95         sess.run(get_next)
     96       self.assertEqual(shuffled_elements, reshuffled_elements_same_seed)
     97 
     98       # Assert that shuffling twice with a different seed gives a different
     99       # permutation of the same elements.
    100       sess.run(
    101           init_shuffle_op,
    102           feed_dict={buffer_size_placeholder: 100,
    103                      seed_placeholder: 1037})
    104       reshuffled_elements_different_seed = []
    105       for _ in range(20):
    106         reshuffled_elements_different_seed.append(sess.run(get_next))
    107       with self.assertRaises(errors.OutOfRangeError):
    108         sess.run(get_next)
    109       self.assertNotEqual(shuffled_elements, reshuffled_elements_different_seed)
    110       self.assertAllEqual(
    111           sorted(shuffled_elements), sorted(reshuffled_elements_different_seed))
    112 
    113       # Assert that the shuffled dataset has the same elements as the
    114       # "ground truth" when the buffer size is smaller than the input
    115       # dataset.
    116       sess.run(
    117           init_shuffle_op,
    118           feed_dict={buffer_size_placeholder: 2,
    119                      seed_placeholder: 37})
    120       reshuffled_elements_small_buffer = []
    121       for _ in range(20):
    122         reshuffled_elements_small_buffer.append(sess.run(get_next))
    123       with self.assertRaises(errors.OutOfRangeError):
    124         sess.run(get_next)
    125       self.assertAllEqual(
    126           sorted(unshuffled_elements), sorted(reshuffled_elements_small_buffer))
    127 
    128       # Test the case of shuffling an empty dataset.
    129       sess.run(init_shuffle_op, feed_dict={buffer_size_placeholder: 2,
    130                                            seed_placeholder: 37,
    131                                            count_placeholder: 0})
    132       with self.assertRaises(errors.OutOfRangeError):
    133         sess.run(get_next)
    134 
    135   def testDefaultArguments(self):
    136     components = [0, 1, 2, 3, 4]
    137     iterator = (dataset_ops.Dataset.from_tensor_slices(components).shuffle(5)
    138                 .repeat().make_one_shot_iterator())
    139 
    140     get_next = iterator.get_next()
    141 
    142     with self.test_session() as sess:
    143       counts = collections.defaultdict(lambda: 0)
    144       for _ in range(10):
    145         for _ in range(5):
    146           counts[sess.run(get_next)] += 1
    147 
    148     for i in range(5):
    149       self.assertEqual(10, counts[i])
    150 
    151   def testShuffleNoReshuffleEachIteration(self):
    152     iterator = (dataset_ops.Dataset.range(10)
    153                 .shuffle(10, reshuffle_each_iteration=False)
    154                 .batch(10)
    155                 .repeat(3)
    156                 .make_one_shot_iterator())
    157     next_element = iterator.get_next()
    158 
    159     with self.test_session() as sess:
    160       initial_permutation = sess.run(next_element)
    161       self.assertAllEqual(initial_permutation, sess.run(next_element))
    162       self.assertAllEqual(initial_permutation, sess.run(next_element))
    163       with self.assertRaises(errors.OutOfRangeError):
    164         sess.run(next_element)
    165 
    166   def testShuffleReshuffleEachIteration(self):
    167     iterator = (dataset_ops.Dataset.range(10)
    168                 .shuffle(10, seed=3, reshuffle_each_iteration=True)
    169                 .batch(10)
    170                 .repeat(3)
    171                 .make_one_shot_iterator())
    172     next_element = iterator.get_next()
    173 
    174     with self.test_session() as sess:
    175       initial_permutation = list(sess.run(next_element))
    176       for _ in range(2):
    177         next_permutation = list(sess.run(next_element))
    178         self.assertNotEqual(initial_permutation, next_permutation)
    179         self.assertAllEqual(
    180             sorted(initial_permutation), sorted(next_permutation))
    181       with self.assertRaises(errors.OutOfRangeError):
    182         sess.run(next_element)
    183 
    184 if __name__ == "__main__":
    185   test.main()
    186