Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for `tf.data.experimental.make_tf_record_dataset()`."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
     21 from tensorflow.python.data.experimental.ops import readers
     22 from tensorflow.python.data.ops import dataset_ops
     23 from tensorflow.python.data.util import nest
     24 from tensorflow.python.framework import errors
     25 from tensorflow.python.framework import test_util
     26 from tensorflow.python.ops import string_ops
     27 from tensorflow.python.platform import test
     28 
     29 
     30 @test_util.run_all_in_graph_and_eager_modes
     31 class MakeTFRecordDatasetTest(
     32     reader_dataset_ops_test_base.TFRecordDatasetTestBase):
     33 
     34   def _read_test(self, batch_size, num_epochs, file_index=None,
     35                  num_parallel_reads=1, drop_final_batch=False, parser_fn=False):
     36     if file_index is None:
     37       file_pattern = self.test_filenames
     38     else:
     39       file_pattern = self.test_filenames[file_index]
     40 
     41     if parser_fn:
     42       fn = lambda x: string_ops.substr(x, 1, 999)
     43     else:
     44       fn = None
     45 
     46     outputs = self.getNext(
     47         readers.make_tf_record_dataset(
     48             file_pattern=file_pattern,
     49             num_epochs=num_epochs,
     50             batch_size=batch_size,
     51             parser_fn=fn,
     52             num_parallel_reads=num_parallel_reads,
     53             drop_final_batch=drop_final_batch,
     54             shuffle=False))
     55     self._verify_records(
     56         outputs,
     57         batch_size,
     58         file_index,
     59         num_epochs=num_epochs,
     60         interleave_cycle_length=num_parallel_reads,
     61         drop_final_batch=drop_final_batch,
     62         use_parser_fn=parser_fn)
     63     with self.assertRaises(errors.OutOfRangeError):
     64       self.evaluate(outputs())
     65 
     66   def testRead(self):
     67     for batch_size in [1, 2]:
     68       for num_epochs in [1, 3]:
     69         # Basic test: read from file 0.
     70         self._read_test(batch_size, num_epochs, 0)
     71 
     72         # Basic test: read from file 1.
     73         self._read_test(batch_size, num_epochs, 1)
     74 
     75         # Basic test: read from both files.
     76         self._read_test(batch_size, num_epochs)
     77 
     78         # Basic test: read from both files, with parallel reads.
     79         self._read_test(batch_size, num_epochs, num_parallel_reads=8)
     80 
     81   def testDropFinalBatch(self):
     82     for batch_size in [1, 2, 10]:
     83       for num_epochs in [1, 3]:
     84         # Read from file 0.
     85         self._read_test(batch_size, num_epochs, 0, drop_final_batch=True)
     86 
     87         # Read from both files.
     88         self._read_test(batch_size, num_epochs, drop_final_batch=True)
     89 
     90         # Read from both files, with parallel reads.
     91         self._read_test(batch_size, num_epochs, num_parallel_reads=8,
     92                         drop_final_batch=True)
     93 
     94   def testParserFn(self):
     95     for batch_size in [1, 2]:
     96       for num_epochs in [1, 3]:
     97         for drop_final_batch in [False, True]:
     98           self._read_test(batch_size, num_epochs, parser_fn=True,
     99                           drop_final_batch=drop_final_batch)
    100           self._read_test(batch_size, num_epochs, num_parallel_reads=8,
    101                           parser_fn=True, drop_final_batch=drop_final_batch)
    102 
    103   def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1,
    104                     seed=None):
    105     dataset = readers.make_tf_record_dataset(
    106         file_pattern=self.test_filenames,
    107         num_epochs=num_epochs,
    108         batch_size=batch_size,
    109         num_parallel_reads=num_parallel_reads,
    110         shuffle=True,
    111         shuffle_seed=seed)
    112 
    113     next_element = self.getNext(dataset)
    114     first_batches = []
    115     try:
    116       while True:
    117         first_batches.append(self.evaluate(next_element()))
    118     except errors.OutOfRangeError:
    119       pass
    120 
    121     next_element = self.getNext(dataset)
    122     second_batches = []
    123     try:
    124       while True:
    125         second_batches.append(self.evaluate(next_element()))
    126     except errors.OutOfRangeError:
    127       pass
    128 
    129     self.assertEqual(len(first_batches), len(second_batches))
    130     if seed is not None:
    131       # if you set a seed, should get the same results
    132       for i in range(len(first_batches)):
    133         self.assertAllEqual(first_batches[i], second_batches[i])
    134 
    135     expected = []
    136     for f in range(self._num_files):
    137       for r in range(self._num_records):
    138         expected.extend([self._record(f, r)] * num_epochs)
    139 
    140     for batches in (first_batches, second_batches):
    141       actual = []
    142       for b in batches:
    143         actual.extend(b)
    144       self.assertAllEqual(sorted(expected), sorted(actual))
    145 
    146   def testShuffle(self):
    147     for batch_size in [1, 2]:
    148       for num_epochs in [1, 3]:
    149         for num_parallel_reads in [1, 2]:
    150           # Test that all expected elements are produced
    151           self._shuffle_test(batch_size, num_epochs, num_parallel_reads)
    152           # Test that elements are produced in a consistent order if
    153           # you specify a seed.
    154           self._shuffle_test(batch_size, num_epochs, num_parallel_reads,
    155                              seed=21345)
    156 
    157   def testIndefiniteRepeatShapeInference(self):
    158     dataset = readers.make_tf_record_dataset(
    159         file_pattern=self.test_filenames, num_epochs=None, batch_size=32)
    160     for shape in nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)):
    161       self.assertEqual(32, shape[0])
    162 
    163 
if __name__ == "__main__":
  # Executed as a script: hand control to `main()` of the imported `test`
  # module (`tensorflow.python.platform.test`).
  test.main()
    166