# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base classes for testing reader datasets."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import zlib

from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.framework import dtypes
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import parsing_ops
from tensorflow.python.util import compat


class FixedLengthRecordDatasetTestBase(test_base.DatasetTestBase):
  """Base class for setting up and testing FixedLengthRecordDataset."""

  def setUp(self):
    super(FixedLengthRecordDatasetTestBase, self).setUp()
    self._num_files = 2
    self._num_records = 7
    self._header_bytes = 5
    self._record_bytes = 3
    self._footer_bytes = 2

  def _record(self, f, r):
    return compat.as_bytes(str(f * 2 + r) * self._record_bytes)

  def _createFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn, "wb") as f:
        f.write(b"H" * self._header_bytes)
        for j in range(self._num_records):
          f.write(self._record(i, j))
        f.write(b"F" * self._footer_bytes)
    return filenames
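
  # A minimal read-back sketch (not used by the tests themselves): the files
  # produced by `_createFiles` can be consumed with
  # `core_readers.FixedLengthRecordDataset`, e.g.
  #
  #   dataset = core_readers.FixedLengthRecordDataset(
  #       filenames,
  #       record_bytes=self._record_bytes,
  #       header_bytes=self._header_bytes,
  #       footer_bytes=self._footer_bytes)
  #
  # which yields self._num_files * self._num_records elements of
  # self._record_bytes bytes each.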


class MakeBatchedFeaturesDatasetTestBase(test_base.DatasetTestBase):
  """Base class for setting up and testing `make_batched_features_dataset`."""

  def setUp(self):
    super(MakeBatchedFeaturesDatasetTestBase, self).setUp()
    self._num_files = 2
    self._num_records = 7
    self.test_filenames = self._createFiles()

  def make_batch_feature(self,
                         filenames,
                         num_epochs,
                         batch_size,
                         label_key=None,
                         reader_num_threads=1,
                         parser_num_threads=1,
                         shuffle=False,
                         shuffle_seed=None,
                         drop_final_batch=False):
    self.filenames = filenames
    self.num_epochs = num_epochs
    self.batch_size = batch_size

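    # The dataset built below yields a dict of parsed feature tensors, or a
    # (features_dict, label) tuple when `label_key` is provided.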
    return readers.make_batched_features_dataset(
        file_pattern=self.filenames,
        batch_size=self.batch_size,
        features={
            "file": parsing_ops.FixedLenFeature([], dtypes.int64),
            "record": parsing_ops.FixedLenFeature([], dtypes.int64),
            "keywords": parsing_ops.VarLenFeature(dtypes.string),
            "label": parsing_ops.FixedLenFeature([], dtypes.string),
        },
        label_key=label_key,
        reader=core_readers.TFRecordDataset,
        num_epochs=self.num_epochs,
        shuffle=shuffle,
        shuffle_seed=shuffle_seed,
        reader_num_threads=reader_num_threads,
        parser_num_threads=parser_num_threads,
        drop_final_batch=drop_final_batch)

  def _record(self, f, r, l):
    example = example_pb2.Example(
        features=feature_pb2.Features(
            feature={
                "file":
                    feature_pb2.Feature(
                        int64_list=feature_pb2.Int64List(value=[f])),
                "record":
                    feature_pb2.Feature(
                        int64_list=feature_pb2.Int64List(value=[r])),
                "keywords":
                    feature_pb2.Feature(
                        bytes_list=feature_pb2.BytesList(
                            value=self._get_keywords(f, r))),
                "label":
                    feature_pb2.Feature(
                        bytes_list=feature_pb2.BytesList(
                            value=[compat.as_bytes(l)]))
            }))
    return example.SerializeToString()
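
  # For example, `_record(0, 1, "fake-label")` serializes an Example whose
  # text form is roughly:
  #
  #   features {
  #     feature { key: "file" value { int64_list { value: 0 } } }
  #     feature { key: "record" value { int64_list { value: 1 } } }
  #     feature { key: "keywords"
  #               value { bytes_list { value: ["keyword0", "keyword1"] } } }
  #     feature { key: "label" value { bytes_list { value: "fake-label" } } }
  #   }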

  def _get_keywords(self, f, r):
    num_keywords = 1 + (f + r) % 2
    keywords = []
    for index in range(num_keywords):
      keywords.append(compat.as_bytes("keyword%d" % index))
    return keywords

  def _sum_keywords(self, num_files):
    sum_keywords = 0
    for i in range(num_files):
      for j in range(self._num_records):
        sum_keywords += 1 + (i + j) % 2
    return sum_keywords

  def _createFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      writer = python_io.TFRecordWriter(fn)
      for j in range(self._num_records):
        writer.write(self._record(i, j, "fake-label"))
      writer.close()
    return filenames

  def _run_actual_batch(self, outputs, label_key_provided=False):
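    """Evaluates one batch from `outputs` and flattens it into a list.

    The returned list is [file, keywords.indices, keywords.values,
    keywords.dense_shape, record, label], matching the layout yielded by
    `_next_expected_batch`.
    """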
    if label_key_provided:
      # outputs would be a tuple of (feature dict, label)
      features, label = self.evaluate(outputs())
    else:
      features = self.evaluate(outputs())
      label = features["label"]
    file_out = features["file"]
    keywords_indices = features["keywords"].indices
    keywords_values = features["keywords"].values
    keywords_dense_shape = features["keywords"].dense_shape
    record = features["record"]
    return ([
        file_out, keywords_indices, keywords_values, keywords_dense_shape,
        record, label
    ])

  def _next_actual_batch(self, label_key_provided=False):
    return self._run_actual_batch(self.outputs, label_key_provided)

  def _interleave(self, iterators, cycle_length):
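    """Yields elements from `iterators` in a round-robin fashion.

    At most `cycle_length` iterators are open at a time; when one is
    exhausted it is replaced by the next pending iterator. For example,
    interleaving iterators over [1, 2, 3] and [4, 5] with cycle_length=2
    yields 1, 4, 2, 5, 3.
    """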
    pending_iterators = iterators
    open_iterators = []
    num_open = 0
    for i in range(cycle_length):
      if pending_iterators:
        open_iterators.append(pending_iterators.pop(0))
        num_open += 1

    while num_open:
      for i in range(min(cycle_length, len(open_iterators))):
        if open_iterators[i] is None:
          continue
        try:
          yield next(open_iterators[i])
        except StopIteration:
          if pending_iterators:
            open_iterators[i] = pending_iterators.pop(0)
          else:
            open_iterators[i] = None
            num_open -= 1

  def _next_expected_batch(self,
                           file_indices,
                           batch_size,
                           num_epochs,
                           cycle_length=1):
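    """Generates the expected batches of `make_batch_feature` output.

    Each yielded batch is a list with the same layout as the one returned by
    `_run_actual_batch`: [file, keywords indices, keywords values,
    keywords dense shape, record, label].
    """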

    def _next_record(file_indices):
      for j in file_indices:
        for i in range(self._num_records):
          yield j, i, compat.as_bytes("fake-label")

    def _next_record_interleaved(file_indices, cycle_length):
      return self._interleave([_next_record([i]) for i in file_indices],
                              cycle_length)

    file_batch = []
    keywords_batch_indices = []
    keywords_batch_values = []
    keywords_batch_max_len = 0
    record_batch = []
    batch_index = 0
    label_batch = []
    for _ in range(num_epochs):
      if cycle_length == 1:
        next_records = _next_record(file_indices)
      else:
        next_records = _next_record_interleaved(file_indices, cycle_length)
      for record in next_records:
        f = record[0]
        r = record[1]
        label_batch.append(record[2])
        file_batch.append(f)
        record_batch.append(r)
        keywords = self._get_keywords(f, r)
        keywords_batch_values.extend(keywords)
        keywords_batch_indices.extend(
            [[batch_index, i] for i in range(len(keywords))])
        batch_index += 1
        keywords_batch_max_len = max(keywords_batch_max_len, len(keywords))
        if len(file_batch) == batch_size:
          yield [
              file_batch, keywords_batch_indices, keywords_batch_values,
              [batch_size, keywords_batch_max_len], record_batch, label_batch
          ]
          file_batch = []
          keywords_batch_indices = []
          keywords_batch_values = []
          keywords_batch_max_len = 0
          record_batch = []
          batch_index = 0
          label_batch = []
    if file_batch:
      yield [
          file_batch, keywords_batch_indices, keywords_batch_values,
          [len(file_batch), keywords_batch_max_len], record_batch, label_batch
      ]

  def verify_records(self,
                     batch_size,
                     file_index=None,
                     num_epochs=1,
                     label_key_provided=False,
                     interleave_cycle_length=1):
    if file_index is not None:
      file_indices = [file_index]
    else:
      file_indices = range(self._num_files)

    for expected_batch in self._next_expected_batch(
        file_indices,
        batch_size,
        num_epochs,
        cycle_length=interleave_cycle_length):
      actual_batch = self._next_actual_batch(
          label_key_provided=label_key_provided)
      for i in range(len(expected_batch)):
        self.assertAllEqual(expected_batch[i], actual_batch[i])
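
  # A hedged usage sketch (following the conventions of the concrete tests,
  # which define `self.outputs`): a test typically builds the dataset, stores
  # an element getter, and then compares against the expected batches, e.g.
  #
  #   dataset = self.make_batch_feature(
  #       filenames=self.test_filenames, num_epochs=1, batch_size=2)
  #   self.outputs = self.getNext(dataset)
  #   self.verify_records(batch_size=2)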


class TextLineDatasetTestBase(test_base.DatasetTestBase):
  """Base class for setting up and testing TextLineDataset."""

  def _lineText(self, f, l):
    return compat.as_bytes("%d: %d" % (f, l))

  def _createFiles(self,
                   num_files,
                   num_lines,
                   crlf=False,
                   compression_type=None):
    filenames = []
    for i in range(num_files):
      fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
      filenames.append(fn)
      contents = []
      for j in range(num_lines):
        contents.append(self._lineText(i, j))
        # Always include a newline after the record unless it is
        # at the end of the file, in which case we only include it
        # for file number 0.
        if j + 1 != num_lines or i == 0:
          contents.append(b"\r\n" if crlf else b"\n")
      contents = b"".join(contents)

      if not compression_type:
        with open(fn, "wb") as f:
          f.write(contents)
      elif compression_type == "GZIP":
        with gzip.GzipFile(fn, "wb") as f:
          f.write(contents)
      elif compression_type == "ZLIB":
        contents = zlib.compress(contents)
        with open(fn, "wb") as f:
          f.write(contents)
      else:
        raise ValueError("Unsupported compression_type: %s" % compression_type)

    return filenames
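
  # A minimal read-back sketch (not exercised here): files written by
  # `_createFiles` can be consumed with `core_readers.TextLineDataset`, e.g.
  #
  #   dataset = core_readers.TextLineDataset(
  #       filenames, compression_type=compression_type)
  #
  # which yields one element per line with the trailing newline stripped.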


class TFRecordDatasetTestBase(test_base.DatasetTestBase):
  """Base class for setting up and testing TFRecordDataset."""

  def _interleave(self, iterators, cycle_length):
    pending_iterators = iterators
    open_iterators = []
    num_open = 0
    for i in range(cycle_length):
      if pending_iterators:
        open_iterators.append(pending_iterators.pop(0))
        num_open += 1

    while num_open:
      for i in range(min(cycle_length, len(open_iterators))):
        if open_iterators[i] is None:
          continue
        try:
          yield next(open_iterators[i])
        except StopIteration:
          if pending_iterators:
            open_iterators[i] = pending_iterators.pop(0)
          else:
            open_iterators[i] = None
            num_open -= 1

  def _next_expected_batch(self, file_indices, batch_size, num_epochs,
                           cycle_length, drop_final_batch, use_parser_fn):
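    """Generates the expected batches of serialized records.

    When `use_parser_fn` is true, the first byte of each record is dropped,
    mirroring a parser function that strips it; when `drop_final_batch` is
    true, any trailing partial batch is omitted.
    """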

    def _next_record(file_indices):
      for j in file_indices:
        for i in range(self._num_records):
          yield j, i

    def _next_record_interleaved(file_indices, cycle_length):
      return self._interleave([_next_record([i]) for i in file_indices],
                              cycle_length)

    record_batch = []
    batch_index = 0
    for _ in range(num_epochs):
      if cycle_length == 1:
        next_records = _next_record(file_indices)
      else:
        next_records = _next_record_interleaved(file_indices, cycle_length)
      for f, r in next_records:
        record = self._record(f, r)
        if use_parser_fn:
          record = record[1:]
        record_batch.append(record)
        batch_index += 1
        if len(record_batch) == batch_size:
          yield record_batch
          record_batch = []
          batch_index = 0
    if record_batch and not drop_final_batch:
      yield record_batch

  def _verify_records(self, outputs, batch_size, file_index, num_epochs,
                      interleave_cycle_length, drop_final_batch, use_parser_fn):
    if file_index is not None:
      if isinstance(file_index, list):
        file_indices = file_index
      else:
        file_indices = [file_index]
    else:
      file_indices = range(self._num_files)

    for expected_batch in self._next_expected_batch(
        file_indices, batch_size, num_epochs, interleave_cycle_length,
        drop_final_batch, use_parser_fn):
      actual_batch = self.evaluate(outputs())
      self.assertAllEqual(expected_batch, actual_batch)

  def setUp(self):
    super(TFRecordDatasetTestBase, self).setUp()
    self._num_files = 2
    self._num_records = 7

    self.test_filenames = self._createFiles()

  def _record(self, f, r):
    return compat.as_bytes("Record %d of file %d" % (r, f))

  def _createFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      writer = python_io.TFRecordWriter(fn)
      for j in range(self._num_records):
        writer.write(self._record(i, j))
      writer.close()
    return filenames
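
  # A minimal read-back sketch (not exercised here): the files created above
  # can be consumed with `core_readers.TFRecordDataset`, e.g.
  #
  #   dataset = core_readers.TFRecordDataset(self.test_filenames)
  #
  # which yields the serialized strings produced by `_record`.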
    404