# This file belongs in tensorflow/python/kernel_tests.
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for the experimental input pipeline ops."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import gzip
     21 import os
     22 import zlib
     23 
     24 from tensorflow.python.data.ops import iterator_ops
     25 from tensorflow.python.data.ops import readers
     26 from tensorflow.python.framework import constant_op
     27 from tensorflow.python.framework import dtypes
     28 from tensorflow.python.framework import errors
     29 from tensorflow.python.framework import ops
     30 from tensorflow.python.framework import tensor_shape
     31 from tensorflow.python.lib.io import python_io
     32 from tensorflow.python.ops import array_ops
     33 from tensorflow.python.ops import gen_dataset_ops
     34 from tensorflow.python.ops import io_ops
     35 from tensorflow.python.ops import parsing_ops
     36 from tensorflow.python.platform import test
     37 from tensorflow.python.util import compat
     38 
     39 
     40 class TextLineDatasetTest(test.TestCase):
     41 
     42   def _lineText(self, f, l):
     43     return compat.as_bytes("%d: %d" % (f, l))
     44 
     45   def _createFiles(self,
     46                    num_files,
     47                    num_lines,
     48                    crlf=False,
     49                    compression_type=None):
     50     filenames = []
     51     for i in range(num_files):
     52       fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
     53       filenames.append(fn)
     54       contents = []
     55       for j in range(num_lines):
     56         contents.append(self._lineText(i, j))
     57         # Always include a newline after the record unless it is
     58         # at the end of the file, in which case we include it
     59         if j + 1 != num_lines or i == 0:
     60           contents.append(b"\r\n" if crlf else b"\n")
     61       contents = b"".join(contents)
     62 
     63       if not compression_type:
     64         with open(fn, "wb") as f:
     65           f.write(contents)
     66       elif compression_type == "GZIP":
     67         with gzip.GzipFile(fn, "wb") as f:
     68           f.write(contents)
     69       elif compression_type == "ZLIB":
     70         contents = zlib.compress(contents)
     71         with open(fn, "wb") as f:
     72           f.write(contents)
     73       else:
     74         raise ValueError("Unsupported compression_type", compression_type)
     75 
     76     return filenames
     77 
     78   def _testTextLineDataset(self, compression_type=None):
     79     test_filenames = self._createFiles(
     80         2, 5, crlf=True, compression_type=compression_type)
     81     filenames = array_ops.placeholder(dtypes.string, shape=[None])
     82     num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
     83     batch_size = array_ops.placeholder(dtypes.int64, shape=[])
     84 
     85     repeat_dataset = readers.TextLineDataset(
     86         filenames, compression_type=compression_type).repeat(num_epochs)
     87     batch_dataset = repeat_dataset.batch(batch_size)
     88 
     89     iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
     90     init_op = iterator.make_initializer(repeat_dataset)
     91     init_batch_op = iterator.make_initializer(batch_dataset)
     92     get_next = iterator.get_next()
     93 
     94     with self.test_session() as sess:
     95       # Basic test: read from file 0.
     96       sess.run(
     97           init_op, feed_dict={filenames: [test_filenames[0]],
     98                               num_epochs: 1})
     99       for i in range(5):
    100         self.assertEqual(self._lineText(0, i), sess.run(get_next))
    101       with self.assertRaises(errors.OutOfRangeError):
    102         sess.run(get_next)
    103 
    104       # Basic test: read from file 1.
    105       sess.run(
    106           init_op, feed_dict={filenames: [test_filenames[1]],
    107                               num_epochs: 1})
    108       for i in range(5):
    109         self.assertEqual(self._lineText(1, i), sess.run(get_next))
    110       with self.assertRaises(errors.OutOfRangeError):
    111         sess.run(get_next)
    112 
    113       # Basic test: read from both files.
    114       sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 1})
    115       for j in range(2):
    116         for i in range(5):
    117           self.assertEqual(self._lineText(j, i), sess.run(get_next))
    118       with self.assertRaises(errors.OutOfRangeError):
    119         sess.run(get_next)
    120 
    121       # Test repeated iteration through both files.
    122       sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 10})
    123       for _ in range(10):
    124         for j in range(2):
    125           for i in range(5):
    126             self.assertEqual(self._lineText(j, i), sess.run(get_next))
    127       with self.assertRaises(errors.OutOfRangeError):
    128         sess.run(get_next)
    129 
    130       # Test batched and repeated iteration through both files.
    131       sess.run(
    132           init_batch_op,
    133           feed_dict={filenames: test_filenames,
    134                      num_epochs: 10,
    135                      batch_size: 5})
    136       for _ in range(10):
    137         self.assertAllEqual([self._lineText(0, i) for i in range(5)],
    138                             sess.run(get_next))
    139         self.assertAllEqual([self._lineText(1, i) for i in range(5)],
    140                             sess.run(get_next))
    141 
    142   def testTextLineDatasetNoCompression(self):
    143     self._testTextLineDataset()
    144 
    145   def testTextLineDatasetGzipCompression(self):
    146     self._testTextLineDataset(compression_type="GZIP")
    147 
    148   def testTextLineDatasetZlibCompression(self):
    149     self._testTextLineDataset(compression_type="ZLIB")
    150 
    151   def testTextLineDatasetBuffering(self):
    152     test_filenames = self._createFiles(2, 5, crlf=True)
    153 
    154     repeat_dataset = readers.TextLineDataset(test_filenames, buffer_size=10)
    155     iterator = repeat_dataset.make_one_shot_iterator()
    156 
    157     with self.test_session() as sess:
    158       for j in range(2):
    159         for i in range(5):
    160           self.assertEqual(self._lineText(j, i), sess.run(iterator.get_next()))
    161       with self.assertRaises(errors.OutOfRangeError):
    162         sess.run(iterator.get_next())
    163 
    164 
    165 class FixedLengthRecordReaderTest(test.TestCase):
    166 
    167   def setUp(self):
    168     super(FixedLengthRecordReaderTest, self).setUp()
    169     self._num_files = 2
    170     self._num_records = 7
    171     self._header_bytes = 5
    172     self._record_bytes = 3
    173     self._footer_bytes = 2
    174 
    175   def _record(self, f, r):
    176     return compat.as_bytes(str(f * 2 + r) * self._record_bytes)
    177 
    178   def _createFiles(self):
    179     filenames = []
    180     for i in range(self._num_files):
    181       fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
    182       filenames.append(fn)
    183       with open(fn, "wb") as f:
    184         f.write(b"H" * self._header_bytes)
    185         for j in range(self._num_records):
    186           f.write(self._record(i, j))
    187         f.write(b"F" * self._footer_bytes)
    188     return filenames
    189 
    190   def testFixedLengthRecordDataset(self):
    191     test_filenames = self._createFiles()
    192     filenames = array_ops.placeholder(dtypes.string, shape=[None])
    193     num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
    194     batch_size = array_ops.placeholder(dtypes.int64, shape=[])
    195 
    196     repeat_dataset = (readers.FixedLengthRecordDataset(
    197         filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
    198                       .repeat(num_epochs))
    199     batch_dataset = repeat_dataset.batch(batch_size)
    200 
    201     iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
    202     init_op = iterator.make_initializer(repeat_dataset)
    203     init_batch_op = iterator.make_initializer(batch_dataset)
    204     get_next = iterator.get_next()
    205 
    206     with self.test_session() as sess:
    207       # Basic test: read from file 0.
    208       sess.run(
    209           init_op, feed_dict={filenames: [test_filenames[0]],
    210                               num_epochs: 1})
    211       for i in range(self._num_records):
    212         self.assertEqual(self._record(0, i), sess.run(get_next))
    213       with self.assertRaises(errors.OutOfRangeError):
    214         sess.run(get_next)
    215 
    216       # Basic test: read from file 1.
    217       sess.run(
    218           init_op, feed_dict={filenames: [test_filenames[1]],
    219                               num_epochs: 1})
    220       for i in range(self._num_records):
    221         self.assertEqual(self._record(1, i), sess.run(get_next))
    222       with self.assertRaises(errors.OutOfRangeError):
    223         sess.run(get_next)
    224 
    225       # Basic test: read from both files.
    226       sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 1})
    227       for j in range(self._num_files):
    228         for i in range(self._num_records):
    229           self.assertEqual(self._record(j, i), sess.run(get_next))
    230       with self.assertRaises(errors.OutOfRangeError):
    231         sess.run(get_next)
    232 
    233       # Test repeated iteration through both files.
    234       sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 10})
    235       for _ in range(10):
    236         for j in range(self._num_files):
    237           for i in range(self._num_records):
    238             self.assertEqual(self._record(j, i), sess.run(get_next))
    239       with self.assertRaises(errors.OutOfRangeError):
    240         sess.run(get_next)
    241 
    242       # Test batched and repeated iteration through both files.
    243       sess.run(
    244           init_batch_op,
    245           feed_dict={
    246               filenames: test_filenames,
    247               num_epochs: 10,
    248               batch_size: self._num_records
    249           })
    250       for _ in range(10):
    251         for j in range(self._num_files):
    252           self.assertAllEqual(
    253               [self._record(j, i) for i in range(self._num_records)],
    254               sess.run(get_next))
    255       with self.assertRaises(errors.OutOfRangeError):
    256         sess.run(get_next)
    257 
    258   def testFixedLengthRecordDatasetBuffering(self):
    259     test_filenames = self._createFiles()
    260     dataset = readers.FixedLengthRecordDataset(
    261         test_filenames,
    262         self._record_bytes,
    263         self._header_bytes,
    264         self._footer_bytes,
    265         buffer_size=10)
    266     iterator = dataset.make_one_shot_iterator()
    267 
    268     with self.test_session() as sess:
    269       for j in range(self._num_files):
    270         for i in range(self._num_records):
    271           self.assertEqual(self._record(j, i), sess.run(iterator.get_next()))
    272       with self.assertRaises(errors.OutOfRangeError):
    273         sess.run(iterator.get_next())
    274 
    275   def testFixedLengthRecordDatasetWrongSize(self):
    276     test_filenames = self._createFiles()
    277     dataset = readers.FixedLengthRecordDataset(
    278         test_filenames,
    279         self._record_bytes + 1,  # Incorrect record length.
    280         self._header_bytes,
    281         self._footer_bytes,
    282         buffer_size=10)
    283     iterator = dataset.make_one_shot_iterator()
    284 
    285     with self.test_session() as sess:
    286       with self.assertRaisesRegexp(
    287           errors.InvalidArgumentError,
    288           r"Excluding the header \(5 bytes\) and footer \(2 bytes\), input "
    289           r"file \".*fixed_length_record.0.txt\" has body length 21 bytes, "
    290           r"which is not an exact multiple of the record length \(4 bytes\)."):
    291         sess.run(iterator.get_next())
    292 
    293   def _iterator_checkpoint_path(self):
    294     return os.path.join(self.get_temp_dir(), "iterator")
    295 
    296   def _save_op(self, iterator_resource):
    297     iterator_state_variant = gen_dataset_ops.serialize_iterator(
    298         iterator_resource)
    299     save_op = io_ops.write_file(
    300         self._iterator_checkpoint_path(),
    301         parsing_ops.serialize_tensor(iterator_state_variant))
    302     return save_op
    303 
    304   def _restore_op(self, iterator_resource):
    305     iterator_state_variant = parsing_ops.parse_tensor(
    306         io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant)
    307     restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
    308                                                       iterator_state_variant)
    309     return restore_op
    310 
    311   def _build_iterator_graph(self, num_epochs):
    312     filenames = self._createFiles()
    313     dataset = (readers.FixedLengthRecordDataset(
    314         filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
    315                .repeat(num_epochs))
    316     iterator = dataset.make_initializable_iterator()
    317     init_op = iterator.initializer
    318     get_next_op = iterator.get_next()
    319     save_op = self._save_op(iterator._iterator_resource)
    320     restore_op = self._restore_op(iterator._iterator_resource)
    321     return init_op, get_next_op, save_op, restore_op
    322 
    323   def _restore_iterator(self):
    324     output_types = dtypes.string
    325     output_shapes = tensor_shape.scalar()
    326     iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
    327     get_next = iterator.get_next()
    328     restore_op = self._restore_op(iterator._iterator_resource)
    329     return restore_op, get_next
    330 
    331   def testSaveRestore(self):
    332     num_epochs = 10
    333     epoch_break = 5
    334     file_break = self._num_files // 2
    335     record_break = self._num_records // 2
    336 
    337     with ops.Graph().as_default() as g:
    338       init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
    339           num_epochs=num_epochs)
    340       with self.test_session(graph=g) as sess:
    341         sess.run(init_op)
    342         # Note: There is no checkpoint saved currently so a NotFoundError is
    343         # raised.
    344         with self.assertRaises(errors.NotFoundError):
    345           sess.run(restore_op)
    346         for epoch in range(num_epochs):
    347           for f in range(self._num_files):
    348             for r in range(self._num_records):
    349               if (epoch == epoch_break and f == file_break and
    350                   r == record_break):
    351                 sess.run(save_op)
    352                 break
    353               self.assertEqual(self._record(f, r), sess.run(get_next_op))
    354             else:
    355               continue
    356             break
    357           else:
    358             continue
    359           break
    360         else:
    361           with self.assertRaises(errors.OutOfRangeError):
    362             sess.run(get_next_op)
    363 
    364     with ops.Graph().as_default() as g:
    365       init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
    366           num_epochs=num_epochs)
    367       with self.test_session(graph=g) as sess:
    368         sess.run(restore_op)
    369         for epoch in range(num_epochs):
    370           for f in range(self._num_files):
    371             for r in range(self._num_records):
    372               if (epoch < epoch_break or
    373                   (epoch == epoch_break and f < file_break) or
    374                   (epoch == epoch_break and f == file_break and
    375                    r < record_break)):
    376                 continue
    377               self.assertEqual(self._record(f, r), sess.run(get_next_op))
    378         with self.assertRaises(errors.OutOfRangeError):
    379           sess.run(get_next_op)
    380 
    381   def testInitThenRestore(self):
    382     # Note: Calling init_op before restore_op is redundant. This test just makes
    383     # sure we do not fail if restore is called on an already initialized
    384     # iterator resource.
    385     num_epochs = 10
    386     epoch_break = 5
    387     file_break = self._num_files // 2
    388     record_break = self._num_records // 2
    389 
    390     with ops.Graph().as_default() as g:
    391       init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
    392           num_epochs=num_epochs)
    393       with self.test_session(graph=g) as sess:
    394         sess.run(init_op)
    395         # Note: There is no checkpoint saved currently so a NotFoundError is
    396         # raised.
    397         with self.assertRaises(errors.NotFoundError):
    398           sess.run(restore_op)
    399         for epoch in range(num_epochs):
    400           for f in range(self._num_files):
    401             for r in range(self._num_records):
    402               if (epoch == epoch_break and f == file_break and
    403                   r == record_break):
    404                 sess.run(save_op)
    405                 break
    406               self.assertEqual(self._record(f, r), sess.run(get_next_op))
    407             else:
    408               continue
    409             break
    410           else:
    411             continue
    412           break
    413         else:
    414           with self.assertRaises(errors.OutOfRangeError):
    415             sess.run(get_next_op)
    416 
    417     with ops.Graph().as_default() as g:
    418       init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
    419           num_epochs=num_epochs)
    420       with self.test_session(graph=g) as sess:
    421         sess.run(init_op)
    422         sess.run(restore_op)
    423         for epoch in range(num_epochs):
    424           for f in range(self._num_files):
    425             for r in range(self._num_records):
    426               if (epoch < epoch_break or
    427                   (epoch == epoch_break and f < file_break) or
    428                   (epoch == epoch_break and f == file_break and
    429                    r < record_break)):
    430                 continue
    431               self.assertEqual(self._record(f, r), sess.run(get_next_op))
    432         with self.assertRaises(errors.OutOfRangeError):
    433           sess.run(get_next_op)
    434 
    435   def testRestoreInModifiedGraph(self):
    436     num_epochs = 10
    437     num_epochs_1 = 20
    438     epoch_break = 5
    439     file_break = self._num_files // 2
    440     record_break = self._num_records // 2
    441 
    442     with ops.Graph().as_default() as g:
    443       init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
    444           num_epochs=num_epochs)
    445       with self.test_session(graph=g) as sess:
    446         sess.run(init_op)
    447         # Note: There is no checkpoint saved currently so a NotFoundError is
    448         # raised.
    449         with self.assertRaises(errors.NotFoundError):
    450           sess.run(restore_op)
    451         for epoch in range(num_epochs):
    452           for f in range(self._num_files):
    453             for r in range(self._num_records):
    454               if (epoch == epoch_break and f == file_break and
    455                   r == record_break):
    456                 sess.run(save_op)
    457                 break
    458               self.assertEqual(self._record(f, r), sess.run(get_next_op))
    459             else:
    460               continue
    461             break
    462           else:
    463             continue
    464           break
    465         else:
    466           with self.assertRaises(errors.OutOfRangeError):
    467             sess.run(get_next_op)
    468 
    469     with ops.Graph().as_default() as g:
    470       init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
    471           num_epochs=num_epochs_1)
    472       with self.test_session(graph=g) as sess:
    473         sess.run(restore_op)
    474         for epoch in range(num_epochs):
    475           for f in range(self._num_files):
    476             for r in range(self._num_records):
    477               if (epoch < epoch_break or
    478                   (epoch == epoch_break and f < file_break) or
    479                   (epoch == epoch_break and f == file_break and
    480                    r < record_break)):
    481                 continue
    482               self.assertEqual(self._record(f, r), sess.run(get_next_op))
    483         with self.assertRaises(errors.OutOfRangeError):
    484           sess.run(get_next_op)
    485 
    486   def testRestoreWithoutBuildingDatasetGraph(self):
    487     num_epochs = 10
    488     epoch_break = 5
    489     file_break = self._num_files // 2
    490     record_break = self._num_records // 2
    491 
    492     with ops.Graph().as_default() as g:
    493       init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
    494           num_epochs=num_epochs)
    495       with self.test_session(graph=g) as sess:
    496         sess.run(init_op)
    497         # Note: There is no checkpoint saved currently so a NotFoundError is
    498         # raised.
    499         with self.assertRaises(errors.NotFoundError):
    500           sess.run(restore_op)
    501         for epoch in range(num_epochs):
    502           for f in range(self._num_files):
    503             for r in range(self._num_records):
    504               if (epoch == epoch_break and f == file_break and
    505                   r == record_break):
    506                 sess.run(save_op)
    507                 break
    508               self.assertEqual(self._record(f, r), sess.run(get_next_op))
    509             else:
    510               continue
    511             break
    512           else:
    513             continue
    514           break
    515         else:
    516           with self.assertRaises(errors.OutOfRangeError):
    517             sess.run(get_next_op)
    518 
    519     with ops.Graph().as_default() as g:
    520       restore_op, get_next_op = self._restore_iterator()
    521       with self.test_session(graph=g) as sess:
    522         sess.run(restore_op)
    523         for epoch in range(num_epochs):
    524           for f in range(self._num_files):
    525             for r in range(self._num_records):
    526               if (epoch < epoch_break or
    527                   (epoch == epoch_break and f < file_break) or
    528                   (epoch == epoch_break and f == file_break and
    529                    r < record_break)):
    530                 continue
    531               self.assertEqual(self._record(f, r), sess.run(get_next_op))
    532         with self.assertRaises(errors.OutOfRangeError):
    533           sess.run(get_next_op)
    534 
    535   def testRestoreUnusedIterator(self):
    536     num_epochs = 10
    537     with ops.Graph().as_default() as g:
    538       init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
    539           num_epochs=num_epochs)
    540       with self.test_session(graph=g) as sess:
    541         sess.run(init_op)
    542         # Note: There is no checkpoint saved currently so a NotFoundError is
    543         # raised.
    544         with self.assertRaises(errors.NotFoundError):
    545           sess.run(restore_op)
    546         # Save unused iterator.
    547         sess.run(save_op)
    548     with ops.Graph().as_default() as g:
    549       init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
    550           num_epochs=num_epochs)
    551       with self.test_session(graph=g) as sess:
    552         sess.run(restore_op)
    553         for _ in range(num_epochs * self._num_files * self._num_records):
    554           sess.run(get_next_op)
    555         with self.assertRaises(errors.OutOfRangeError):
    556           sess.run(get_next_op)
    557 
    558   def testRestoreExhaustedIterator(self):
    559     num_epochs = 10
    560 
    561     with ops.Graph().as_default() as g:
    562       init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
    563           num_epochs=num_epochs)
    564       with self.test_session(graph=g) as sess:
    565         sess.run(init_op)
    566         # Note: There is no checkpoint saved currently so a NotFoundError is
    567         # raised.
    568         with self.assertRaises(errors.NotFoundError):
    569           sess.run(restore_op)
    570         for _ in range(num_epochs):
    571           for f in range(self._num_files):
    572             for r in range(self._num_records):
    573               self.assertEqual(self._record(f, r), sess.run(get_next_op))
    574         with self.assertRaises(errors.OutOfRangeError):
    575           sess.run(get_next_op)
    576         sess.run(save_op)
    577 
    578     with ops.Graph().as_default() as g:
    579       init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
    580           num_epochs=num_epochs)
    581       with self.test_session(graph=g) as sess:
    582         sess.run(restore_op)
    583         with self.assertRaises(errors.OutOfRangeError):
    584           sess.run(get_next_op)
    585 
    586 
class TFRecordDatasetTest(test.TestCase):
  """Tests for `TFRecordDataset`, including GZIP/ZLIB-compressed inputs."""

  def setUp(self):
    """Builds one reader graph in `setUp`, shared by every test below."""
    super(TFRecordDatasetTest, self).setUp()
    self._num_files = 2
    self._num_records = 7

    self.test_filenames = self._createFiles()

    # Placeholders let each test feed its own file list, epoch count,
    # compression type, and batch size into the shared graph.
    self.filenames = array_ops.placeholder(dtypes.string, shape=[None])
    self.num_epochs = array_ops.placeholder_with_default(
        constant_op.constant(1, dtypes.int64), shape=[])
    self.compression_type = array_ops.placeholder_with_default("", shape=[])
    self.batch_size = array_ops.placeholder(dtypes.int64, shape=[])

    repeat_dataset = readers.TFRecordDataset(self.filenames,
                                             self.compression_type).repeat(
                                                 self.num_epochs)
    batch_dataset = repeat_dataset.batch(self.batch_size)

    # One reinitializable iterator serves both the batched and unbatched
    # pipelines.
    iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
    self.init_op = iterator.make_initializer(repeat_dataset)
    self.init_batch_op = iterator.make_initializer(batch_dataset)
    self.get_next = iterator.get_next()

  def _record(self, f, r):
    """Returns the expected bytes of record `r` of file `f`."""
    return compat.as_bytes("Record %d of file %d" % (r, f))

  def _createFiles(self):
    """Writes `self._num_files` TFRecord files and returns their names."""
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      writer = python_io.TFRecordWriter(fn)
      for j in range(self._num_records):
        writer.write(self._record(i, j))
      writer.close()
    return filenames

  def testReadOneEpoch(self):
    """Reads each file individually and then both together, one epoch each."""
    with self.test_session() as sess:
      # Basic test: read from file 0.
      sess.run(
          self.init_op,
          feed_dict={
              self.filenames: [self.test_filenames[0]],
              self.num_epochs: 1
          })
      for i in range(self._num_records):
        self.assertAllEqual(self._record(0, i), sess.run(self.get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.get_next)

      # Basic test: read from file 1.
      sess.run(
          self.init_op,
          feed_dict={
              self.filenames: [self.test_filenames[1]],
              self.num_epochs: 1
          })
      for i in range(self._num_records):
        self.assertAllEqual(self._record(1, i), sess.run(self.get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.get_next)

      # Basic test: read from both files.
      sess.run(
          self.init_op,
          feed_dict={self.filenames: self.test_filenames,
                     self.num_epochs: 1})
      for j in range(self._num_files):
        for i in range(self._num_records):
          self.assertAllEqual(self._record(j, i), sess.run(self.get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.get_next)

  def testReadTenEpochs(self):
    """Repeats both files ten times and checks every record in order."""
    with self.test_session() as sess:
      sess.run(
          self.init_op,
          feed_dict={self.filenames: self.test_filenames,
                     self.num_epochs: 10})
      for _ in range(10):
        for j in range(self._num_files):
          for i in range(self._num_records):
            self.assertAllEqual(self._record(j, i), sess.run(self.get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.get_next)

  def testReadTenEpochsOfBatches(self):
    """Batches by file size so each batch holds exactly one file's records."""
    with self.test_session() as sess:
      sess.run(
          self.init_batch_op,
          feed_dict={
              self.filenames: self.test_filenames,
              self.num_epochs: 10,
              self.batch_size: self._num_records
          })
      for _ in range(10):
        for j in range(self._num_files):
          values = sess.run(self.get_next)
          self.assertAllEqual(
              [self._record(j, i) for i in range(self._num_records)], values)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.get_next)

  def testReadZlibFiles(self):
    """Compresses the fixture files with zlib and reads them back."""
    zlib_files = []
    for i, fn in enumerate(self.test_filenames):
      with open(fn, "rb") as f:
        cdata = zlib.compress(f.read())

        zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i)
        with open(zfn, "wb") as f:
          f.write(cdata)
        zlib_files.append(zfn)

    with self.test_session() as sess:
      sess.run(
          self.init_op,
          feed_dict={self.filenames: zlib_files,
                     self.compression_type: "ZLIB"})
      for j in range(self._num_files):
        for i in range(self._num_records):
          self.assertAllEqual(self._record(j, i), sess.run(self.get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.get_next)

  def testReadGzipFiles(self):
    """Compresses the fixture files with gzip and reads them back."""
    gzip_files = []
    for i, fn in enumerate(self.test_filenames):
      with open(fn, "rb") as f:
        gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
        with gzip.GzipFile(gzfn, "wb") as gzf:
          gzf.write(f.read())
        gzip_files.append(gzfn)

    with self.test_session() as sess:
      sess.run(
          self.init_op,
          feed_dict={self.filenames: gzip_files,
                     self.compression_type: "GZIP"})
      for j in range(self._num_files):
        for i in range(self._num_records):
          self.assertAllEqual(self._record(j, i), sess.run(self.get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.get_next)

  def testReadWithBuffer(self):
    """Reads correctly with an explicit buffer size."""
    one_mebibyte = 2**20
    d = readers.TFRecordDataset(self.test_filenames, buffer_size=one_mebibyte)
    iterator = d.make_one_shot_iterator()
    with self.test_session() as sess:
      for j in range(self._num_files):
        for i in range(self._num_records):
          self.assertAllEqual(self._record(j, i), sess.run(iterator.get_next()))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(iterator.get_next())

    746 
if __name__ == "__main__":
  # Run all test cases in this module under the TensorFlow test runner.
  test.main()
    749