Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for the experimental input pipeline ops."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import numpy as np
     21 
     22 from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
     23 from tensorflow.contrib.data.python.ops import shuffle_ops
     24 from tensorflow.python.data.ops import dataset_ops
     25 from tensorflow.python.framework import errors
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.platform import test
     28 
     29 
     30 class ShuffleDatasetSerializationTest(
     31     dataset_serialization_test_base.DatasetSerializationTestBase):
     32 
     33   def _build_shuffle_dataset(
     34       self,
     35       range_limit=10,
     36       num_repeats=5,
     37       buffer_size=5,
     38       seed=None,
     39       reshuffle_each_iteration=None,
     40   ):
     41     return dataset_ops.Dataset.range(range_limit).shuffle(
     42         buffer_size,
     43         seed=seed,
     44         reshuffle_each_iteration=reshuffle_each_iteration).repeat(num_repeats)
     45 
     46   def testShuffleCore(self):
     47 
     48     seed = 55
     49     range_limit = 10
     50     num_repeats = 5
     51     num_outputs = range_limit * num_repeats
     52     buffer_sizes = [1, 3, 8, 10, 25, 50]
     53     reshuffle_each_iteration = False
     54     # pylint: disable=cell-var-from-loop
     55     # pylint: disable=g-long-lambda
     56     for buffer_size in buffer_sizes:
     57       self.run_core_tests(
     58           lambda: self._build_shuffle_dataset(
     59               range_limit=range_limit,
     60               num_repeats=num_repeats,
     61               buffer_size=buffer_size,
     62               seed=seed,
     63               reshuffle_each_iteration=reshuffle_each_iteration),
     64           lambda: self._build_shuffle_dataset(
     65               range_limit=range_limit,
     66               num_repeats=num_repeats,
     67               buffer_size=buffer_size,
     68               seed=10,
     69               reshuffle_each_iteration=reshuffle_each_iteration),
     70           num_outputs)
     71     # pylint: enable=cell-var-from-loop
     72     # pylint: enable=g-long-lambda
     73 
     74 
     75 class ShuffleAndRepeatTest(
     76     dataset_serialization_test_base.DatasetSerializationTestBase):
     77 
     78   def _build_ds(self, seed, count=5, num_elements=20):
     79     return dataset_ops.Dataset.range(num_elements).apply(
     80         shuffle_ops.shuffle_and_repeat(buffer_size=5, count=count, seed=seed))
     81 
     82   def testCorrectOutput(self):
     83     output = self.gen_outputs(lambda: self._build_ds(10), [], 100)
     84     self.assertSequenceEqual(
     85         sorted(output), sorted(
     86             np.array([range(20) for _ in range(5)]).flatten()))
     87     for i in range(5):
     88       self.assertSequenceEqual(sorted(output[i * 20:(i + 1) * 20]), range(20))
     89 
     90   def testReshuffling(self):
     91     # Check that the output orders of different epochs are indeed different.
     92     output = self.gen_outputs(lambda: self._build_ds(10), [], 100)
     93     for i in range(4):
     94       epoch1 = output[i * 20:(i + 1) * 20]
     95       epoch2 = output[(i + 1) * 20:(i + 2) * 20]
     96       self.assertNotEqual(epoch1, epoch2)
     97 
     98   def testSameOrderForSameSeeds(self):
     99     output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100)
    100     output2 = self.gen_outputs(lambda: self._build_ds(10), [], 100)
    101     self.assertEqual(output1, output2)
    102 
    103   def testDifferentOrderForDifferentSeeds(self):
    104     output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100)
    105     output2 = self.gen_outputs(lambda: self._build_ds(20), [], 100)
    106     self.assertNotEqual(output1, output2)
    107     self.assertEqual(sorted(output1), sorted(output2))
    108 
    109   def testCountNone(self):
    110     output1 = self.gen_outputs(
    111         lambda: self._build_ds(10, count=None), [], 100, verify_exhausted=False)
    112     output2 = self.gen_outputs(
    113         lambda: self._build_ds(20, count=None), [], 100, verify_exhausted=False)
    114     self.assertNotEqual(output1, output2)
    115     self.assertEqual(sorted(output1), sorted(output2))
    116 
    117   def testCountMinusOne(self):
    118     output1 = self.gen_outputs(
    119         lambda: self._build_ds(10, count=-1), [], 100, verify_exhausted=False)
    120     output2 = self.gen_outputs(
    121         lambda: self._build_ds(20, count=-1), [], 100, verify_exhausted=False)
    122     self.assertNotEqual(output1, output2)
    123     self.assertEqual(sorted(output1), sorted(output2))
    124 
    125   def testInfiniteOutputs(self):
    126     # Asserting the iterator is exhausted after producing 100 items should fail.
    127     with self.assertRaises(AssertionError):
    128       self.gen_outputs(lambda: self._build_ds(10, count=None), [], 100)
    129     with self.assertRaises(AssertionError):
    130       self.gen_outputs(lambda: self._build_ds(10, count=-1), [], 100)
    131 
    132   def testInfiniteEmpty(self):
    133     with self.assertRaises(errors.OutOfRangeError):
    134       self.gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0),
    135                        [], 100)
    136     with self.assertRaises(errors.OutOfRangeError):
    137       self.gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0), [],
    138                        100)
    139 
    140   def testLargeBufferSize(self):
    141     with ops.Graph().as_default() as g:
    142       ds = dataset_ops.Dataset.range(20).apply(
    143           shuffle_ops.shuffle_and_repeat(buffer_size=21))
    144       get_next_op = ds.make_one_shot_iterator().get_next()
    145       with self.test_session(graph=g) as sess:
    146         sess.run(get_next_op)
    147 
    148 
    149 class ShuffleAndRepeatSerializationTest(
    150     dataset_serialization_test_base.DatasetSerializationTestBase):
    151 
    152   def _build_ds(self, seed):
    153     return dataset_ops.Dataset.range(20).apply(
    154         shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed))
    155 
    156   def testCore(self):
    157     self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20),
    158                         100)
    159 
    160 
# Run the tests in this module through `test.main()` when the file is executed
# directly rather than imported.
if __name__ == "__main__":
  test.main()
    163