1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for the experimental input pipeline ops.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import collections 21 22 import numpy as np 23 24 from tensorflow.python.data.ops import dataset_ops 25 from tensorflow.python.data.ops import iterator_ops 26 from tensorflow.python.framework import constant_op 27 from tensorflow.python.framework import dtypes 28 from tensorflow.python.framework import errors 29 from tensorflow.python.ops import array_ops 30 from tensorflow.python.platform import test 31 32 33 class ShuffleDatasetTest(test.TestCase): 34 35 def testShuffleDataset(self): 36 components = ( 37 np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]), 38 np.array([9.0, 10.0, 11.0, 12.0]) 39 ) 40 count_placeholder = array_ops.placeholder_with_default( 41 constant_op.constant(5, dtypes.int64), shape=[]) 42 buffer_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) 43 seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) 44 45 repeat_dataset = (dataset_ops.Dataset.from_tensor_slices(components) 46 .repeat(count_placeholder)) 47 48 shuffle_dataset = repeat_dataset.shuffle(buffer_size_placeholder, 49 seed_placeholder) 50 51 self.assertEqual(tuple([c.shape[1:] for c in components]), 52 shuffle_dataset.output_shapes) 53 54 # Create initialization ops for iterators without and with 55 # shuffling, respectively. 56 iterator = iterator_ops.Iterator.from_structure( 57 shuffle_dataset.output_types, shuffle_dataset.output_shapes) 58 init_fifo_op = iterator.make_initializer(repeat_dataset) 59 init_shuffle_op = iterator.make_initializer(shuffle_dataset) 60 61 get_next = iterator.get_next() 62 63 with self.test_session() as sess: 64 # First run without shuffling to collect the "ground truth". 65 sess.run(init_fifo_op) 66 unshuffled_elements = [] 67 for _ in range(20): 68 unshuffled_elements.append(sess.run(get_next)) 69 with self.assertRaises(errors.OutOfRangeError): 70 sess.run(get_next) 71 72 # Assert that the shuffled dataset has the same elements as the 73 # "ground truth". 74 sess.run( 75 init_shuffle_op, 76 feed_dict={buffer_size_placeholder: 100, 77 seed_placeholder: 37}) 78 shuffled_elements = [] 79 for _ in range(20): 80 shuffled_elements.append(sess.run(get_next)) 81 with self.assertRaises(errors.OutOfRangeError): 82 sess.run(get_next) 83 self.assertAllEqual( 84 sorted(unshuffled_elements), sorted(shuffled_elements)) 85 86 # Assert that shuffling twice with the same seeds gives the same sequence. 87 sess.run( 88 init_shuffle_op, 89 feed_dict={buffer_size_placeholder: 100, 90 seed_placeholder: 37}) 91 reshuffled_elements_same_seed = [] 92 for _ in range(20): 93 reshuffled_elements_same_seed.append(sess.run(get_next)) 94 with self.assertRaises(errors.OutOfRangeError): 95 sess.run(get_next) 96 self.assertEqual(shuffled_elements, reshuffled_elements_same_seed) 97 98 # Assert that shuffling twice with a different seed gives a different 99 # permutation of the same elements. 100 sess.run( 101 init_shuffle_op, 102 feed_dict={buffer_size_placeholder: 100, 103 seed_placeholder: 1037}) 104 reshuffled_elements_different_seed = [] 105 for _ in range(20): 106 reshuffled_elements_different_seed.append(sess.run(get_next)) 107 with self.assertRaises(errors.OutOfRangeError): 108 sess.run(get_next) 109 self.assertNotEqual(shuffled_elements, reshuffled_elements_different_seed) 110 self.assertAllEqual( 111 sorted(shuffled_elements), sorted(reshuffled_elements_different_seed)) 112 113 # Assert that the shuffled dataset has the same elements as the 114 # "ground truth" when the buffer size is smaller than the input 115 # dataset. 116 sess.run( 117 init_shuffle_op, 118 feed_dict={buffer_size_placeholder: 2, 119 seed_placeholder: 37}) 120 reshuffled_elements_small_buffer = [] 121 for _ in range(20): 122 reshuffled_elements_small_buffer.append(sess.run(get_next)) 123 with self.assertRaises(errors.OutOfRangeError): 124 sess.run(get_next) 125 self.assertAllEqual( 126 sorted(unshuffled_elements), sorted(reshuffled_elements_small_buffer)) 127 128 # Test the case of shuffling an empty dataset. 129 sess.run(init_shuffle_op, feed_dict={buffer_size_placeholder: 2, 130 seed_placeholder: 37, 131 count_placeholder: 0}) 132 with self.assertRaises(errors.OutOfRangeError): 133 sess.run(get_next) 134 135 def testDefaultArguments(self): 136 components = [0, 1, 2, 3, 4] 137 iterator = (dataset_ops.Dataset.from_tensor_slices(components).shuffle(5) 138 .repeat().make_one_shot_iterator()) 139 140 get_next = iterator.get_next() 141 142 with self.test_session() as sess: 143 counts = collections.defaultdict(lambda: 0) 144 for _ in range(10): 145 for _ in range(5): 146 counts[sess.run(get_next)] += 1 147 148 for i in range(5): 149 self.assertEqual(10, counts[i]) 150 151 def testShuffleNoReshuffleEachIteration(self): 152 iterator = (dataset_ops.Dataset.range(10) 153 .shuffle(10, reshuffle_each_iteration=False) 154 .batch(10) 155 .repeat(3) 156 .make_one_shot_iterator()) 157 next_element = iterator.get_next() 158 159 with self.test_session() as sess: 160 initial_permutation = sess.run(next_element) 161 self.assertAllEqual(initial_permutation, sess.run(next_element)) 162 self.assertAllEqual(initial_permutation, sess.run(next_element)) 163 with self.assertRaises(errors.OutOfRangeError): 164 sess.run(next_element) 165 166 def testShuffleReshuffleEachIteration(self): 167 iterator = (dataset_ops.Dataset.range(10) 168 .shuffle(10, seed=3, reshuffle_each_iteration=True) 169 .batch(10) 170 .repeat(3) 171 .make_one_shot_iterator()) 172 next_element = iterator.get_next() 173 174 with self.test_session() as sess: 175 initial_permutation = list(sess.run(next_element)) 176 for _ in range(2): 177 next_permutation = list(sess.run(next_element)) 178 self.assertNotEqual(initial_permutation, next_permutation) 179 self.assertAllEqual( 180 sorted(initial_permutation), sorted(next_permutation)) 181 with self.assertRaises(errors.OutOfRangeError): 182 sess.run(next_element) 183 184 if __name__ == "__main__": 185 test.main() 186