# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import zlib

from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import readers
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat


class TextLineDatasetTestBase(test.TestCase):
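  """Base class providing helpers to create (optionally compressed) text files."""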

  def _lineText(self, f, l):
    return compat.as_bytes("%d: %d" % (f, l))

  def _createFiles(self,
                   num_files,
                   num_lines,
                   crlf=False,
                   compression_type=None):
    filenames = []
    for i in range(num_files):
      fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
      filenames.append(fn)
      contents = []
      for j in range(num_lines):
        contents.append(self._lineText(i, j))
        # Always include a newline after the record, except after the last
        # record of a file; only the first file keeps its trailing newline,
        # so both the trailing-newline and no-trailing-newline cases are
        # exercised.
        if j + 1 != num_lines or i == 0:
          contents.append(b"\r\n" if crlf else b"\n")
      contents = b"".join(contents)

      if not compression_type:
        with open(fn, "wb") as f:
          f.write(contents)
      elif compression_type == "GZIP":
        with gzip.GzipFile(fn, "wb") as f:
          f.write(contents)
      elif compression_type == "ZLIB":
        contents = zlib.compress(contents)
        with open(fn, "wb") as f:
          f.write(contents)
      else:
        raise ValueError("Unsupported compression_type", compression_type)

    return filenames


class TextLineDatasetSerializationTest(
    TextLineDatasetTestBase,
    dataset_serialization_test_base.DatasetSerializationTestBase):
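  """Checkpoint/restore (serialization) tests for TextLineDataset."""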

  def _build_iterator_graph(self, test_filenames, compression_type=None):
    return core_readers.TextLineDataset(
        test_filenames, compression_type=compression_type, buffer_size=10)

  def testTextLineCore(self):
    compression_types = [None, "GZIP", "ZLIB"]
    num_files = 5
    lines_per_file = 5
    num_outputs = num_files * lines_per_file
    for compression_type in compression_types:
      test_filenames = self._createFiles(
          num_files,
          lines_per_file,
          crlf=True,
          compression_type=compression_type)
      # pylint: disable=cell-var-from-loop
      self.run_core_tests(
          lambda: self._build_iterator_graph(test_filenames, compression_type),
          lambda: self._build_iterator_graph(test_filenames), num_outputs)
      # pylint: enable=cell-var-from-loop


class FixedLengthRecordReaderTestBase(test.TestCase):
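  """Base class that creates fixed-length record files for testing."""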

  def setUp(self):
    super(FixedLengthRecordReaderTestBase, self).setUp()
    self._num_files = 2
    self._num_records = 7
    self._header_bytes = 5
    self._record_bytes = 3
    self._footer_bytes = 2

  def _record(self, f, r):
    return compat.as_bytes(str(f * 2 + r) * self._record_bytes)

  def _createFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn, "wb") as f:
        f.write(b"H" * self._header_bytes)
        for j in range(self._num_records):
          f.write(self._record(i, j))
        f.write(b"F" * self._footer_bytes)
    return filenames


class FixedLengthRecordDatasetSerializationTest(
    FixedLengthRecordReaderTestBase,
    dataset_serialization_test_base.DatasetSerializationTestBase):
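  """Serialization tests for FixedLengthRecordDataset."""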

  def _build_iterator_graph(self, num_epochs, compression_type=None):
    filenames = self._createFiles()
    return core_readers.FixedLengthRecordDataset(
        filenames, self._record_bytes, self._header_bytes,
        self._footer_bytes).repeat(num_epochs)

  def testFixedLengthRecordCore(self):
    num_epochs = 5
    num_outputs = num_epochs * self._num_files * self._num_records
    self.run_core_tests(lambda: self._build_iterator_graph(num_epochs),
                        lambda: self._build_iterator_graph(num_epochs * 2),
                        num_outputs)


class TFRecordDatasetTestBase(test.TestCase):
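  """Base class that writes TFRecord files and builds a reusable reader graph."""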

  def setUp(self):
    super(TFRecordDatasetTestBase, self).setUp()
    self._num_files = 2
    self._num_records = 7

    self.test_filenames = self._createFiles()

    self.filenames = array_ops.placeholder(dtypes.string, shape=[None])
    self.num_epochs = array_ops.placeholder_with_default(
        constant_op.constant(1, dtypes.int64), shape=[])
    self.compression_type = array_ops.placeholder_with_default("", shape=[])
    self.batch_size = array_ops.placeholder(dtypes.int64, shape=[])

    repeat_dataset = core_readers.TFRecordDataset(
        self.filenames, self.compression_type).repeat(self.num_epochs)
    batch_dataset = repeat_dataset.batch(self.batch_size)

    iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
    self.init_op = iterator.make_initializer(repeat_dataset)
    self.init_batch_op = iterator.make_initializer(batch_dataset)
    self.get_next = iterator.get_next()

  def _record(self, f, r):
    return compat.as_bytes("Record %d of file %d" % (r, f))

  def _createFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      writer = python_io.TFRecordWriter(fn)
      for j in range(self._num_records):
        writer.write(self._record(i, j))
      writer.close()
    return filenames


class TFRecordDatasetSerializationTest(
    TFRecordDatasetTestBase,
    dataset_serialization_test_base.DatasetSerializationTestBase):
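  """Serialization tests for TFRecordDataset, with and without compression."""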

  def _build_iterator_graph(self,
                            num_epochs,
                            batch_size=1,
                            compression_type=None,
                            buffer_size=None):
    filenames = self._createFiles()
    if compression_type == "ZLIB":
      zlib_files = []
      for i, fn in enumerate(filenames):
        with open(fn, "rb") as f:
          cdata = zlib.compress(f.read())
          zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i)
          with open(zfn, "wb") as f:
            f.write(cdata)
          zlib_files.append(zfn)
      filenames = zlib_files

    elif compression_type == "GZIP":
      gzip_files = []
      for i, fn in enumerate(filenames):
        with open(fn, "rb") as f:
          gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
          with gzip.GzipFile(gzfn, "wb") as gzf:
            gzf.write(f.read())
          gzip_files.append(gzfn)
      filenames = gzip_files

    return core_readers.TFRecordDataset(
        filenames, compression_type,
        buffer_size=buffer_size).repeat(num_epochs).batch(batch_size)

  def testTFRecordWithoutBufferCore(self):
    num_epochs = 5
    batch_size = num_epochs
    num_outputs = num_epochs * self._num_files * self._num_records // batch_size
    # pylint: disable=g-long-lambda
    self.run_core_tests(
        lambda: self._build_iterator_graph(num_epochs, batch_size,
                                           buffer_size=0),
        lambda: self._build_iterator_graph(num_epochs * 2, batch_size),
        num_outputs)
    self.run_core_tests(
        lambda: self._build_iterator_graph(num_epochs, buffer_size=0), None,
        num_outputs * batch_size)
    # pylint: enable=g-long-lambda

  def testTFRecordWithBufferCore(self):
    num_epochs = 5
    num_outputs = num_epochs * self._num_files * self._num_records
    self.run_core_tests(lambda: self._build_iterator_graph(num_epochs),
                        lambda: self._build_iterator_graph(num_epochs * 2),
                        num_outputs)

  def testTFRecordWithCompressionCore(self):
    num_epochs = 5
    num_outputs = num_epochs * self._num_files * self._num_records
    self.run_core_tests(
        lambda: self._build_iterator_graph(num_epochs, compression_type="ZLIB"),
        lambda: self._build_iterator_graph(num_epochs * 2), num_outputs)
    self.run_core_tests(
        lambda: self._build_iterator_graph(num_epochs, compression_type="GZIP"),
        lambda: self._build_iterator_graph(num_epochs * 2), num_outputs)


class ReadBatchFeaturesTest(test.TestCase):
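  """Tests for the contrib readers.read_batch_features helper."""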

  def setUp(self):
    super(ReadBatchFeaturesTest, self).setUp()
    self._num_files = 2
    self._num_records = 7
    self.test_filenames = self._createFiles()

  def _read_batch_features(self, filenames, num_epochs, batch_size):
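    """Builds a read_batch_features pipeline over the given files."""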
    self.filenames = filenames
    self.num_epochs = num_epochs
    self.batch_size = batch_size

    return readers.read_batch_features(
        file_pattern=self.filenames,
        batch_size=self.batch_size,
        features={
            "file": parsing_ops.FixedLenFeature([], dtypes.int64),
            "record": parsing_ops.FixedLenFeature([], dtypes.int64),
            "keywords": parsing_ops.VarLenFeature(dtypes.string)
        },
        reader=core_readers.TFRecordDataset,
        randomize_input=False,
        num_epochs=self.num_epochs)

  def _record(self, f, r):
    example = example_pb2.Example(features=feature_pb2.Features(
        feature={
            "file":
                feature_pb2.Feature(int64_list=feature_pb2.Int64List(
                    value=[f])),
            "record":
                feature_pb2.Feature(int64_list=feature_pb2.Int64List(
                    value=[r])),
            "keywords":
                feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
                    value=self._get_keywords(f, r)))
        }))
    return example.SerializeToString()

  def _get_keywords(self, f, r):
    num_keywords = 1 + (f + r) % 2
    keywords = []
    for index in range(num_keywords):
      keywords.append(compat.as_bytes("keyword%d" % index))
    return keywords

  def _createFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      writer = python_io.TFRecordWriter(fn)
      for j in range(self._num_records):
        writer.write(self._record(i, j))
      writer.close()
    return filenames

  def _next_actual_batch(self, sess):
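    """Runs the output ops and returns their values for one batch."""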
    file_op = self.outputs["file"]
    keywords_indices_op = self.outputs["keywords"].indices
    keywords_values_op = self.outputs["keywords"].values
    keywords_dense_shape_op = self.outputs["keywords"].dense_shape
    record_op = self.outputs["record"]
    return sess.run([
        file_op, keywords_indices_op, keywords_values_op,
        keywords_dense_shape_op, record_op
    ])

  def _next_expected_batch(self, file_indices, batch_size, num_epochs):
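    """Generates the expected contents of each output batch."""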

    def _next_record(file_indices):
      for j in file_indices:
        for i in range(self._num_records):
          yield j, i

    file_batch = []
    keywords_batch_indices = []
    keywords_batch_values = []
    keywords_batch_max_len = 0
    record_batch = []
    batch_index = 0
    for _ in range(num_epochs):
      for record in _next_record(file_indices):
        f = record[0]
        r = record[1]
        file_batch.append(f)
        record_batch.append(r)
        keywords = self._get_keywords(f, r)
        keywords_batch_values.extend(keywords)
        keywords_batch_indices.extend([[batch_index, i]
                                       for i in range(len(keywords))])
        batch_index += 1
        keywords_batch_max_len = max(keywords_batch_max_len, len(keywords))
        if len(file_batch) == batch_size:
          yield [
              file_batch, keywords_batch_indices, keywords_batch_values,
              [batch_size, keywords_batch_max_len], record_batch
          ]
          file_batch = []
          keywords_batch_indices = []
          keywords_batch_values = []
          keywords_batch_max_len = 0
          record_batch = []
          batch_index = 0
    if file_batch:
      yield [
          file_batch, keywords_batch_indices, keywords_batch_values,
          [len(file_batch), keywords_batch_max_len], record_batch
      ]

  def _verify_records(self, sess, batch_size, file_index=None, num_epochs=1):
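    """Checks batches read from the session against the expected batches."""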
    if file_index is not None:
      file_indices = [file_index]
    else:
      file_indices = range(self._num_files)

    for expected_batch in self._next_expected_batch(file_indices, batch_size,
                                                    num_epochs):
      actual_batch = self._next_actual_batch(sess)
      for i in range(len(expected_batch)):
        self.assertAllEqual(expected_batch[i], actual_batch[i])

  def testRead(self):
    for batch_size in [1, 2]:
      for num_epochs in [1, 10]:
        with ops.Graph().as_default() as g:
          with self.test_session(graph=g) as sess:
            # Basic test: read from file 0.
            self.outputs = self._read_batch_features(
                filenames=self.test_filenames[0],
                num_epochs=num_epochs,
                batch_size=batch_size)
            self._verify_records(sess, batch_size, 0, num_epochs=num_epochs)
            with self.assertRaises(errors.OutOfRangeError):
              self._next_actual_batch(sess)

        with ops.Graph().as_default() as g:
          with self.test_session(graph=g) as sess:
            # Basic test: read from file 1.
            self.outputs = self._read_batch_features(
                filenames=self.test_filenames[1],
                num_epochs=num_epochs,
                batch_size=batch_size)
            self._verify_records(sess, batch_size, 1, num_epochs=num_epochs)
            with self.assertRaises(errors.OutOfRangeError):
              self._next_actual_batch(sess)

        with ops.Graph().as_default() as g:
          with self.test_session(graph=g) as sess:
            # Basic test: read from both files.
            self.outputs = self._read_batch_features(
                filenames=self.test_filenames,
                num_epochs=num_epochs,
                batch_size=batch_size)
            self._verify_records(sess, batch_size, num_epochs=num_epochs)
            with self.assertRaises(errors.OutOfRangeError):
              self._next_actual_batch(sess)

  def testReadWithEquivalentDataset(self):
    features = {
        "file": parsing_ops.FixedLenFeature([], dtypes.int64),
        "record": parsing_ops.FixedLenFeature([], dtypes.int64),
    }
    dataset = (core_readers.TFRecordDataset(self.test_filenames)
               .map(lambda x: parsing_ops.parse_single_example(x, features))
               .repeat(10).batch(2))
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    next_element = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for file_batch, _, _, _, record_batch in self._next_expected_batch(
          range(self._num_files), 2, 10):
        actual_batch = sess.run(next_element)
        self.assertAllEqual(file_batch, actual_batch["file"])
        self.assertAllEqual(record_batch, actual_batch["record"])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)


if __name__ == "__main__":
  test.main()