# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.experimental.make_batched_features_dataset()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class MakeBatchedFeaturesDatasetTest(
    reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
  # NOTE(review): the helpers used throughout — `make_batch_feature`,
  # `getNext`, `verify_records`, `_next_actual_batch`, `_run_actual_batch`,
  # `_next_expected_batch` — and the fixtures `test_filenames`, `_num_files`
  # and `_num_records` are all inherited from the test base class above.

  def testRead(self):
    """Reads batches of parsed features, with and without a label key.

    Covers reading from a single file (file 0 and file 1 individually) and
    from both files together, across batch sizes {1, 2} and epoch counts
    {1, 10}. Each sub-case verifies the record contents and then checks
    that the iterator terminates with `OutOfRangeError`.
    """
    for batch_size in [1, 2]:
      for num_epochs in [1, 10]:
        # Basic test: read from file 0.
        self.outputs = self.getNext(
            self.make_batch_feature(
                filenames=self.test_filenames[0],
                label_key="label",
                num_epochs=num_epochs,
                batch_size=batch_size))
        self.verify_records(
            batch_size, 0, num_epochs=num_epochs, label_key_provided=True)
        with self.assertRaises(errors.OutOfRangeError):
          self._next_actual_batch(label_key_provided=True)

        # Basic test: read from file 1.
        self.outputs = self.getNext(
            self.make_batch_feature(
                filenames=self.test_filenames[1],
                label_key="label",
                num_epochs=num_epochs,
                batch_size=batch_size))
        self.verify_records(
            batch_size, 1, num_epochs=num_epochs, label_key_provided=True)
        with self.assertRaises(errors.OutOfRangeError):
          self._next_actual_batch(label_key_provided=True)

        # Basic test: read from both files, with a label key.
        self.outputs = self.getNext(
            self.make_batch_feature(
                filenames=self.test_filenames,
                label_key="label",
                num_epochs=num_epochs,
                batch_size=batch_size))
        self.verify_records(
            batch_size, num_epochs=num_epochs, label_key_provided=True)
        with self.assertRaises(errors.OutOfRangeError):
          self._next_actual_batch(label_key_provided=True)
        # Basic test: read from both files, without a label key.
        self.outputs = self.getNext(
            self.make_batch_feature(
                filenames=self.test_filenames,
                num_epochs=num_epochs,
                batch_size=batch_size))
        self.verify_records(batch_size, num_epochs=num_epochs)
        with self.assertRaises(errors.OutOfRangeError):
          self._next_actual_batch()

  def testReadWithEquivalentDataset(self):
    """A hand-built TFRecord->parse->repeat->batch pipeline matches the
    expected batches produced by the base class's reference generator."""
    features = {
        "file": parsing_ops.FixedLenFeature([], dtypes.int64),
        "record": parsing_ops.FixedLenFeature([], dtypes.int64),
    }
    # Equivalent pipeline built from primitives: 10 epochs, batch size 2.
    dataset = (
        core_readers.TFRecordDataset(self.test_filenames)
        .map(lambda x: parsing_ops.parse_single_example(x, features))
        .repeat(10).batch(2))
    next_element = self.getNext(dataset)
    # Only the "file" and "record" components of the expected batch tuple
    # are compared; the other four positions are ignored here.
    for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
        range(self._num_files), 2, 10):
      actual_batch = self.evaluate(next_element())
      self.assertAllEqual(file_batch, actual_batch["file"])
      self.assertAllEqual(record_batch, actual_batch["record"])
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  def testReadWithFusedShuffleRepeatDataset(self):
    """Shuffling is deterministic per seed: same seed -> same order,
    different seeds -> (almost surely) different order."""
    num_epochs = 5
    total_records = num_epochs * self._num_records
    for batch_size in [1, 2]:
      # Test that shuffling with same seed produces the same result.
      outputs1 = self.getNext(
          self.make_batch_feature(
              filenames=self.test_filenames[0],
              num_epochs=num_epochs,
              batch_size=batch_size,
              shuffle=True,
              shuffle_seed=5))
      outputs2 = self.getNext(
          self.make_batch_feature(
              filenames=self.test_filenames[0],
              num_epochs=num_epochs,
              batch_size=batch_size,
              shuffle=True,
              shuffle_seed=5))
      for _ in range(total_records // batch_size):
        batch1 = self._run_actual_batch(outputs1)
        batch2 = self._run_actual_batch(outputs2)
        for i in range(len(batch1)):
          self.assertAllEqual(batch1[i], batch2[i])

      # Test that shuffling with different seeds produces a different order.
      outputs1 = self.getNext(
          self.make_batch_feature(
              filenames=self.test_filenames[0],
              num_epochs=num_epochs,
              batch_size=batch_size,
              shuffle=True,
              shuffle_seed=5))
      outputs2 = self.getNext(
          self.make_batch_feature(
              filenames=self.test_filenames[0],
              num_epochs=num_epochs,
              batch_size=batch_size,
              shuffle=True,
              shuffle_seed=15))
      # Any single batch may coincide by chance, so only assert that the
      # two streams are not equal in their entirety.
      all_equal = True
      for _ in range(total_records // batch_size):
        batch1 = self._run_actual_batch(outputs1)
        batch2 = self._run_actual_batch(outputs2)
        for i in range(len(batch1)):
          all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
      self.assertFalse(all_equal)

  def testParallelReadersAndParsers(self):
    """Reading with multiple reader/parser threads still yields all records.

    The interleave cycle length passed to `verify_records` mirrors
    `reader_num_threads`, since parallel readers interleave files.
    """
    num_epochs = 5
    for batch_size in [1, 2]:
      for reader_num_threads in [2, 4]:
        for parser_num_threads in [2, 4]:
          # With a label key.
          self.outputs = self.getNext(
              self.make_batch_feature(
                  filenames=self.test_filenames,
                  label_key="label",
                  num_epochs=num_epochs,
                  batch_size=batch_size,
                  reader_num_threads=reader_num_threads,
                  parser_num_threads=parser_num_threads))
          self.verify_records(
              batch_size,
              num_epochs=num_epochs,
              label_key_provided=True,
              interleave_cycle_length=reader_num_threads)
          with self.assertRaises(errors.OutOfRangeError):
            self._next_actual_batch(label_key_provided=True)

          # Without a label key.
          self.outputs = self.getNext(
              self.make_batch_feature(
                  filenames=self.test_filenames,
                  num_epochs=num_epochs,
                  batch_size=batch_size,
                  reader_num_threads=reader_num_threads,
                  parser_num_threads=parser_num_threads))
          self.verify_records(
              batch_size,
              num_epochs=num_epochs,
              interleave_cycle_length=reader_num_threads)
          with self.assertRaises(errors.OutOfRangeError):
            self._next_actual_batch()

  def testDropFinalBatch(self):
    """With `drop_final_batch=True` the static batch dimension is known.

    Only the statically-inferred shapes are checked (no data is read), so
    the dataset is built inside a fresh graph.
    """
    for batch_size in [1, 2]:
      for num_epochs in [1, 10]:
        with ops.Graph().as_default():
          # Basic test: read from file 0.
          outputs = self.make_batch_feature(
              filenames=self.test_filenames[0],
              label_key="label",
              num_epochs=num_epochs,
              batch_size=batch_size,
              drop_final_batch=True)
          for tensor in nest.flatten(outputs):
            if isinstance(tensor, ops.Tensor):  # Guard against SparseTensor.
              self.assertEqual(tensor.shape[0], batch_size)

  def testIndefiniteRepeatShapeInference(self):
    """With `num_epochs=None` (repeat forever) the batch dimension is still
    statically inferred as `batch_size` for dense outputs."""
    dataset = self.make_batch_feature(
        filenames=self.test_filenames[0],
        label_key="label",
        num_epochs=None,
        batch_size=32)
    for shape, clazz in zip(
        nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)),
        nest.flatten(dataset_ops.get_legacy_output_classes(dataset))):
      if issubclass(clazz, ops.Tensor):
        self.assertEqual(32, shape[0])

  def testOldStyleReader(self):
    """Passing a `tf.ReaderBase` subclass as `reader` raises a TypeError."""
    with self.assertRaisesRegexp(
        TypeError, r"The `reader` argument must return a `Dataset` object. "
        r"`tf.ReaderBase` subclasses are not supported."):
      _ = readers.make_batched_features_dataset(
          file_pattern=self.test_filenames[0], batch_size=32,
          features={
              "file": parsing_ops.FixedLenFeature([], dtypes.int64),
              "record": parsing_ops.FixedLenFeature([], dtypes.int64),
              "keywords": parsing_ops.VarLenFeature(dtypes.string),
              "label": parsing_ops.FixedLenFeature([], dtypes.string),
          },
          reader=io_ops.TFRecordReader)


if __name__ == "__main__":
  test.main()