# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.experimental.make_tf_record_dataset()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class MakeTFRecordDatasetTest(
    reader_dataset_ops_test_base.TFRecordDatasetTestBase):
  """Tests for `readers.make_tf_record_dataset`.

  Exercises unshuffled reads (single file, all files, parallel reads,
  final-batch dropping, parser functions), shuffled reads (multiset equality
  and seed determinism), and static shape inference under indefinite repeat.
  """

  def _read_test(self, batch_size, num_epochs, file_index=None,
                 num_parallel_reads=1, drop_final_batch=False, parser_fn=False):
    """Builds an unshuffled dataset with the given options and verifies it.

    Args:
      batch_size: Batch size passed to `make_tf_record_dataset`.
      num_epochs: Number of epochs to read.
      file_index: If set, read only `self.test_filenames[file_index]`;
        otherwise read every test file.
      num_parallel_reads: Interleave cycle length for parallel file reads.
      drop_final_batch: Whether the final partial batch is dropped.
      parser_fn: If True, apply a parser that strips each record's first byte.
    """
    # Select a single file when an index was given, else the full file list.
    file_pattern = (self.test_filenames if file_index is None
                    else self.test_filenames[file_index])

    # The parser (when requested) drops the first character of every record;
    # 999 is simply "longer than any record", i.e. take the rest.
    fn = (lambda x: string_ops.substr(x, 1, 999)) if parser_fn else None

    outputs = self.getNext(
        readers.make_tf_record_dataset(
            file_pattern=file_pattern,
            num_epochs=num_epochs,
            batch_size=batch_size,
            parser_fn=fn,
            num_parallel_reads=num_parallel_reads,
            drop_final_batch=drop_final_batch,
            shuffle=False))
    self._verify_records(
        outputs,
        batch_size,
        file_index,
        num_epochs=num_epochs,
        interleave_cycle_length=num_parallel_reads,
        drop_final_batch=drop_final_batch,
        use_parser_fn=parser_fn)
    # After every expected record has been verified, the iterator must be
    # exhausted.
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(outputs())

  def testRead(self):
    """Basic unshuffled reads over single files, all files, and in parallel."""
    for batch_size in [1, 2]:
      for num_epochs in [1, 3]:
        # Single-file reads: file 0, then file 1.
        self._read_test(batch_size, num_epochs, 0)
        self._read_test(batch_size, num_epochs, 1)

        # All files, serially.
        self._read_test(batch_size, num_epochs)

        # All files, with parallel interleaved reads.
        self._read_test(batch_size, num_epochs, num_parallel_reads=8)

  def testDropFinalBatch(self):
    """Reads with `drop_final_batch=True` across the same configurations."""
    for batch_size in [1, 2, 10]:
      for num_epochs in [1, 3]:
        # Single file.
        self._read_test(batch_size, num_epochs, 0, drop_final_batch=True)

        # All files, serially.
        self._read_test(batch_size, num_epochs, drop_final_batch=True)

        # All files, with parallel interleaved reads.
        self._read_test(batch_size, num_epochs, num_parallel_reads=8,
                        drop_final_batch=True)

  def testParserFn(self):
    """Reads through a `parser_fn`, with and without final-batch dropping."""
    for batch_size in [1, 2]:
      for num_epochs in [1, 3]:
        for drop_final_batch in [False, True]:
          self._read_test(batch_size, num_epochs, parser_fn=True,
                          drop_final_batch=drop_final_batch)
          self._read_test(batch_size, num_epochs, num_parallel_reads=8,
                          parser_fn=True, drop_final_batch=drop_final_batch)

  def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1,
                    seed=None):
    """Verifies shuffled output: same record multiset, seed determinism.

    Iterates the dataset twice. With a fixed `seed` the two passes must
    match batch-for-batch; in all cases each pass must produce exactly the
    expected multiset of records.
    """
    dataset = readers.make_tf_record_dataset(
        file_pattern=self.test_filenames,
        num_epochs=num_epochs,
        batch_size=batch_size,
        num_parallel_reads=num_parallel_reads,
        shuffle=True,
        shuffle_seed=seed)

    def drain(get_next):
      """Pulls batches until the iterator is exhausted."""
      batches = []
      try:
        while True:
          batches.append(self.evaluate(get_next()))
      except errors.OutOfRangeError:
        pass
      return batches

    first_batches = drain(self.getNext(dataset))
    second_batches = drain(self.getNext(dataset))

    self.assertEqual(len(first_batches), len(second_batches))
    if seed is not None:
      # A fixed seed must reproduce the exact same batch sequence.
      for first, second in zip(first_batches, second_batches):
        self.assertAllEqual(first, second)

    # Every record from every file should appear once per epoch.
    expected = []
    for f in range(self._num_files):
      for r in range(self._num_records):
        expected.extend([self._record(f, r)] * num_epochs)

    for batches in (first_batches, second_batches):
      actual = []
      for b in batches:
        actual.extend(b)
      # Order-insensitive comparison: shuffling permutes, never adds/drops.
      self.assertAllEqual(sorted(expected), sorted(actual))

  def testShuffle(self):
    """Shuffled reads: completeness always, order determinism with a seed."""
    for batch_size in [1, 2]:
      for num_epochs in [1, 3]:
        for num_parallel_reads in [1, 2]:
          # All expected elements are produced (no seed).
          self._shuffle_test(batch_size, num_epochs, num_parallel_reads)
          # Elements come out in a consistent order when a seed is set.
          self._shuffle_test(batch_size, num_epochs, num_parallel_reads,
                             seed=21345)

  def testIndefiniteRepeatShapeInference(self):
    """With `num_epochs=None`, the batch dimension is statically known."""
    dataset = readers.make_tf_record_dataset(
        file_pattern=self.test_filenames, num_epochs=None, batch_size=32)
    for shape in nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)):
      self.assertEqual(32, shape[0])


if __name__ == "__main__":
  test.main()