1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for the experimental input pipeline ops.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import numpy as np 21 22 from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base 23 from tensorflow.contrib.data.python.ops import shuffle_ops 24 from tensorflow.python.data.ops import dataset_ops 25 from tensorflow.python.framework import errors 26 from tensorflow.python.framework import ops 27 from tensorflow.python.platform import test 28 29 30 class ShuffleDatasetSerializationTest( 31 dataset_serialization_test_base.DatasetSerializationTestBase): 32 33 def _build_shuffle_dataset( 34 self, 35 range_limit=10, 36 num_repeats=5, 37 buffer_size=5, 38 seed=None, 39 reshuffle_each_iteration=None, 40 ): 41 return dataset_ops.Dataset.range(range_limit).shuffle( 42 buffer_size, 43 seed=seed, 44 reshuffle_each_iteration=reshuffle_each_iteration).repeat(num_repeats) 45 46 def testShuffleCore(self): 47 48 seed = 55 49 range_limit = 10 50 num_repeats = 5 51 num_outputs = range_limit * num_repeats 52 buffer_sizes = [1, 3, 8, 10, 25, 50] 53 reshuffle_each_iteration = False 54 # pylint: disable=cell-var-from-loop 55 # pylint: disable=g-long-lambda 56 for buffer_size in buffer_sizes: 57 self.run_core_tests( 58 lambda: self._build_shuffle_dataset( 59 range_limit=range_limit, 60 num_repeats=num_repeats, 61 buffer_size=buffer_size, 62 seed=seed, 63 reshuffle_each_iteration=reshuffle_each_iteration), 64 lambda: self._build_shuffle_dataset( 65 range_limit=range_limit, 66 num_repeats=num_repeats, 67 buffer_size=buffer_size, 68 seed=10, 69 reshuffle_each_iteration=reshuffle_each_iteration), 70 num_outputs) 71 # pylint: enable=cell-var-from-loop 72 # pylint: enable=g-long-lambda 73 74 75 class ShuffleAndRepeatTest( 76 dataset_serialization_test_base.DatasetSerializationTestBase): 77 78 def _build_ds(self, seed, count=5, num_elements=20): 79 return dataset_ops.Dataset.range(num_elements).apply( 80 shuffle_ops.shuffle_and_repeat(buffer_size=5, count=count, seed=seed)) 81 82 def testCorrectOutput(self): 83 output = self.gen_outputs(lambda: self._build_ds(10), [], 100) 84 self.assertSequenceEqual( 85 sorted(output), sorted( 86 np.array([range(20) for _ in range(5)]).flatten())) 87 for i in range(5): 88 self.assertSequenceEqual(sorted(output[i * 20:(i + 1) * 20]), range(20)) 89 90 def testReshuffling(self): 91 # Check that the output orders of different epochs are indeed different. 92 output = self.gen_outputs(lambda: self._build_ds(10), [], 100) 93 for i in range(4): 94 epoch1 = output[i * 20:(i + 1) * 20] 95 epoch2 = output[(i + 1) * 20:(i + 2) * 20] 96 self.assertNotEqual(epoch1, epoch2) 97 98 def testSameOrderForSameSeeds(self): 99 output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100) 100 output2 = self.gen_outputs(lambda: self._build_ds(10), [], 100) 101 self.assertEqual(output1, output2) 102 103 def testDifferentOrderForDifferentSeeds(self): 104 output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100) 105 output2 = self.gen_outputs(lambda: self._build_ds(20), [], 100) 106 self.assertNotEqual(output1, output2) 107 self.assertEqual(sorted(output1), sorted(output2)) 108 109 def testCountNone(self): 110 output1 = self.gen_outputs( 111 lambda: self._build_ds(10, count=None), [], 100, verify_exhausted=False) 112 output2 = self.gen_outputs( 113 lambda: self._build_ds(20, count=None), [], 100, verify_exhausted=False) 114 self.assertNotEqual(output1, output2) 115 self.assertEqual(sorted(output1), sorted(output2)) 116 117 def testCountMinusOne(self): 118 output1 = self.gen_outputs( 119 lambda: self._build_ds(10, count=-1), [], 100, verify_exhausted=False) 120 output2 = self.gen_outputs( 121 lambda: self._build_ds(20, count=-1), [], 100, verify_exhausted=False) 122 self.assertNotEqual(output1, output2) 123 self.assertEqual(sorted(output1), sorted(output2)) 124 125 def testInfiniteOutputs(self): 126 # Asserting the iterator is exhausted after producing 100 items should fail. 127 with self.assertRaises(AssertionError): 128 self.gen_outputs(lambda: self._build_ds(10, count=None), [], 100) 129 with self.assertRaises(AssertionError): 130 self.gen_outputs(lambda: self._build_ds(10, count=-1), [], 100) 131 132 def testInfiniteEmpty(self): 133 with self.assertRaises(errors.OutOfRangeError): 134 self.gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0), 135 [], 100) 136 with self.assertRaises(errors.OutOfRangeError): 137 self.gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0), [], 138 100) 139 140 def testLargeBufferSize(self): 141 with ops.Graph().as_default() as g: 142 ds = dataset_ops.Dataset.range(20).apply( 143 shuffle_ops.shuffle_and_repeat(buffer_size=21)) 144 get_next_op = ds.make_one_shot_iterator().get_next() 145 with self.test_session(graph=g) as sess: 146 sess.run(get_next_op) 147 148 149 class ShuffleAndRepeatSerializationTest( 150 dataset_serialization_test_base.DatasetSerializationTestBase): 151 152 def _build_ds(self, seed): 153 return dataset_ops.Dataset.range(20).apply( 154 shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed)) 155 156 def testCore(self): 157 self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20), 158 100) 159 160 161 if __name__ == "__main__": 162 test.main() 163