# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.experimental.make_batched_features_dataset()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
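
# These tests exercise `make_batched_features_dataset()`, mostly through the
# test base's `make_batch_feature()` helper. For orientation, a minimal direct
# call mirrors the (deliberately failing) one in `testOldStyleReader` below,
# minus the unsupported `reader` argument; `filenames` is a placeholder here:
#
#   dataset = readers.make_batched_features_dataset(
#       file_pattern=filenames,
#       batch_size=32,
#       features={"record": parsing_ops.FixedLenFeature([], dtypes.int64)})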


@test_util.run_all_in_graph_and_eager_modes
class MakeBatchedFeaturesDatasetTest(
    reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):

  def testRead(self):
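    """Reads batched features from each file alone, then from both files."""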
    for batch_size in [1, 2]:
      for num_epochs in [1, 10]:
        # Basic test: read from file 0.
        self.outputs = self.getNext(
            self.make_batch_feature(
                filenames=self.test_filenames[0],
                label_key="label",
                num_epochs=num_epochs,
                batch_size=batch_size))
        self.verify_records(
            batch_size, 0, num_epochs=num_epochs, label_key_provided=True)
        with self.assertRaises(errors.OutOfRangeError):
          self._next_actual_batch(label_key_provided=True)

        # Basic test: read from file 1.
        self.outputs = self.getNext(
            self.make_batch_feature(
                filenames=self.test_filenames[1],
                label_key="label",
                num_epochs=num_epochs,
                batch_size=batch_size))
        self.verify_records(
            batch_size, 1, num_epochs=num_epochs, label_key_provided=True)
        with self.assertRaises(errors.OutOfRangeError):
          self._next_actual_batch(label_key_provided=True)

        # Basic test: read from both files, with the label key.
        self.outputs = self.getNext(
            self.make_batch_feature(
                filenames=self.test_filenames,
                label_key="label",
                num_epochs=num_epochs,
                batch_size=batch_size))
        self.verify_records(
            batch_size, num_epochs=num_epochs, label_key_provided=True)
        with self.assertRaises(errors.OutOfRangeError):
          self._next_actual_batch(label_key_provided=True)

        # Basic test: read from both files, without the label key.
        self.outputs = self.getNext(
            self.make_batch_feature(
                filenames=self.test_filenames,
                num_epochs=num_epochs,
                batch_size=batch_size))
        self.verify_records(batch_size, num_epochs=num_epochs)
        with self.assertRaises(errors.OutOfRangeError):
          self._next_actual_batch()

  def testReadWithEquivalentDataset(self):
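    """Checks the reader against an equivalent hand-built TFRecord pipeline."""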
    features = {
        "file": parsing_ops.FixedLenFeature([], dtypes.int64),
        "record": parsing_ops.FixedLenFeature([], dtypes.int64),
    }
    dataset = (
        core_readers.TFRecordDataset(self.test_filenames)
        .map(lambda x: parsing_ops.parse_single_example(x, features))
        .repeat(10).batch(2))
    next_element = self.getNext(dataset)
    for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
        range(self._num_files), 2, 10):
      actual_batch = self.evaluate(next_element())
      self.assertAllEqual(file_batch, actual_batch["file"])
      self.assertAllEqual(record_batch, actual_batch["record"])
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  def testReadWithFusedShuffleRepeatDataset(self):
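    """Shuffles with equal seeds must match; different seeds must not."""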
    num_epochs = 5
    total_records = num_epochs * self._num_records
    for batch_size in [1, 2]:
      # Test that shuffling with the same seed produces the same result.
      outputs1 = self.getNext(
          self.make_batch_feature(
              filenames=self.test_filenames[0],
              num_epochs=num_epochs,
              batch_size=batch_size,
              shuffle=True,
              shuffle_seed=5))
      outputs2 = self.getNext(
          self.make_batch_feature(
              filenames=self.test_filenames[0],
              num_epochs=num_epochs,
              batch_size=batch_size,
              shuffle=True,
              shuffle_seed=5))
      for _ in range(total_records // batch_size):
        batch1 = self._run_actual_batch(outputs1)
        batch2 = self._run_actual_batch(outputs2)
        for i in range(len(batch1)):
          self.assertAllEqual(batch1[i], batch2[i])

      # Test that shuffling with different seeds produces a different order.
      outputs1 = self.getNext(
          self.make_batch_feature(
              filenames=self.test_filenames[0],
              num_epochs=num_epochs,
              batch_size=batch_size,
              shuffle=True,
              shuffle_seed=5))
      outputs2 = self.getNext(
          self.make_batch_feature(
              filenames=self.test_filenames[0],
              num_epochs=num_epochs,
              batch_size=batch_size,
              shuffle=True,
              shuffle_seed=15))
      all_equal = True
      for _ in range(total_records // batch_size):
        batch1 = self._run_actual_batch(outputs1)
        batch2 = self._run_actual_batch(outputs2)
        for i in range(len(batch1)):
          all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
      self.assertFalse(all_equal)

  def testParallelReadersAndParsers(self):
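    """Reads correctly with multiple reader and parser threads."""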
    num_epochs = 5
    for batch_size in [1, 2]:
      for reader_num_threads in [2, 4]:
        for parser_num_threads in [2, 4]:
          self.outputs = self.getNext(
              self.make_batch_feature(
                  filenames=self.test_filenames,
                  label_key="label",
                  num_epochs=num_epochs,
                  batch_size=batch_size,
                  reader_num_threads=reader_num_threads,
                  parser_num_threads=parser_num_threads))
          self.verify_records(
              batch_size,
              num_epochs=num_epochs,
              label_key_provided=True,
              interleave_cycle_length=reader_num_threads)
          with self.assertRaises(errors.OutOfRangeError):
            self._next_actual_batch(label_key_provided=True)

          self.outputs = self.getNext(
              self.make_batch_feature(
                  filenames=self.test_filenames,
                  num_epochs=num_epochs,
                  batch_size=batch_size,
                  reader_num_threads=reader_num_threads,
                  parser_num_threads=parser_num_threads))
          self.verify_records(
              batch_size,
              num_epochs=num_epochs,
              interleave_cycle_length=reader_num_threads)
          with self.assertRaises(errors.OutOfRangeError):
            self._next_actual_batch()

  def testDropFinalBatch(self):
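    """`drop_final_batch=True` makes the batch dimension statically known."""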
    for batch_size in [1, 2]:
      for num_epochs in [1, 10]:
        with ops.Graph().as_default():
          # Basic test: read from file 0.
          outputs = self.make_batch_feature(
              filenames=self.test_filenames[0],
              label_key="label",
              num_epochs=num_epochs,
              batch_size=batch_size,
              drop_final_batch=True)
          for tensor in nest.flatten(outputs):
            if isinstance(tensor, ops.Tensor):  # Guard against SparseTensor.
              self.assertEqual(tensor.shape[0], batch_size)

  def testIndefiniteRepeatShapeInference(self):
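    """With `num_epochs=None`, dense batch shapes are still inferred."""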
    dataset = self.make_batch_feature(
        filenames=self.test_filenames[0],
        label_key="label",
        num_epochs=None,
        batch_size=32)
    for shape, clazz in zip(
        nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)),
        nest.flatten(dataset_ops.get_legacy_output_classes(dataset))):
      if issubclass(clazz, ops.Tensor):
        self.assertEqual(32, shape[0])

  def testOldStyleReader(self):
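    """Passing a `tf.ReaderBase` subclass as `reader` raises `TypeError`."""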
    with self.assertRaisesRegexp(
        TypeError, r"The `reader` argument must return a `Dataset` object. "
        r"`tf.ReaderBase` subclasses are not supported."):
      _ = readers.make_batched_features_dataset(
          file_pattern=self.test_filenames[0], batch_size=32,
          features={
              "file": parsing_ops.FixedLenFeature([], dtypes.int64),
              "record": parsing_ops.FixedLenFeature([], dtypes.int64),
              "keywords": parsing_ops.VarLenFeature(dtypes.string),
              "label": parsing_ops.FixedLenFeature([], dtypes.string),
          },
          reader=io_ops.TFRecordReader)


if __name__ == "__main__":
  test.main()