# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Reader ops from io_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import gzip
import os
import shutil
import threading
import zlib

import six

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.lib.io import tf_record
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.util import compat
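# Root of the checked-in test data tree; LMDBReaderTest below reads the LMDB
# fixture lmdb/testdata/data.mdb from under this path.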
prefix_path = "tensorflow/core/lib"

# pylint: disable=invalid-name
TFRecordCompressionType = tf_record.TFRecordCompressionType
# pylint: enable=invalid-name

# Edgar Allan Poe's 'Eldorado'
_TEXT = b"""Gaily bedight,
    A gallant knight,
    In sunshine and in shadow,
    Had journeyed long,
    Singing a song,
    In search of Eldorado.

    But he grew old
    This knight so bold
    And o'er his heart a shadow
    Fell as he found
    No spot of ground
    That looked like Eldorado.

   And, as his strength
   Failed him at length,
   He met a pilgrim shadow
   'Shadow,' said he,
   'Where can it be
   This land of Eldorado?'

   'Over the Mountains
    Of the Moon'
    Down the Valley of the Shadow,
    Ride, boldly ride,'
    The shade replied,
    'If you seek for Eldorado!'
    """


class IdentityReaderTest(test.TestCase):
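  """Tests IdentityReader, which yields each queued work item as key and value."""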

  def _ExpectRead(self, sess, key, value, expected):
    k, v = sess.run([key, value])
    self.assertAllEqual(expected, k)
    self.assertAllEqual(expected, v)

  def testOneEpoch(self):
    with self.test_session() as sess:
      reader = io_ops.IdentityReader("test_reader")
      work_completed = reader.num_work_units_completed()
      produced = reader.num_records_produced()
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      queued_length = queue.size()
      key, value = reader.read(queue)

      self.assertAllEqual(0, work_completed.eval())
      self.assertAllEqual(0, produced.eval())
      self.assertAllEqual(0, queued_length.eval())

      queue.enqueue_many([["A", "B", "C"]]).run()
      queue.close().run()
      self.assertAllEqual(3, queued_length.eval())

      self._ExpectRead(sess, key, value, b"A")
      self.assertAllEqual(1, produced.eval())

      self._ExpectRead(sess, key, value, b"B")

      self._ExpectRead(sess, key, value, b"C")
      self.assertAllEqual(3, produced.eval())
      self.assertAllEqual(0, queued_length.eval())

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        sess.run([key, value])

      self.assertAllEqual(3, work_completed.eval())
      self.assertAllEqual(3, produced.eval())
      self.assertAllEqual(0, queued_length.eval())

  def testMultipleEpochs(self):
    with self.test_session() as sess:
      reader = io_ops.IdentityReader("test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      enqueue = queue.enqueue_many([["DD", "EE"]])
      key, value = reader.read(queue)

      enqueue.run()
      self._ExpectRead(sess, key, value, b"DD")
      self._ExpectRead(sess, key, value, b"EE")
      enqueue.run()
      self._ExpectRead(sess, key, value, b"DD")
      self._ExpectRead(sess, key, value, b"EE")
      enqueue.run()
      self._ExpectRead(sess, key, value, b"DD")
      self._ExpectRead(sess, key, value, b"EE")
      queue.close().run()
      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        sess.run([key, value])

  def testSerializeRestore(self):
    with self.test_session() as sess:
      reader = io_ops.IdentityReader("test_reader")
      produced = reader.num_records_produced()
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      queue.enqueue_many([["X", "Y", "Z"]]).run()
      key, value = reader.read(queue)

      self._ExpectRead(sess, key, value, b"X")
      self.assertAllEqual(1, produced.eval())
      state = reader.serialize_state().eval()

      self._ExpectRead(sess, key, value, b"Y")
      self._ExpectRead(sess, key, value, b"Z")
      self.assertAllEqual(3, produced.eval())

      queue.enqueue_many([["Y", "Z"]]).run()
      queue.close().run()
      reader.restore_state(state).run()
      self.assertAllEqual(1, produced.eval())
      self._ExpectRead(sess, key, value, b"Y")
      self._ExpectRead(sess, key, value, b"Z")
      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        sess.run([key, value])
      self.assertAllEqual(3, produced.eval())

      self.assertEqual(bytes, type(state))

      with self.assertRaises(ValueError):
        reader.restore_state([])

      with self.assertRaises(ValueError):
        reader.restore_state([state, state])

      with self.assertRaisesOpError(
          "Could not parse state for IdentityReader 'test_reader'"):
        reader.restore_state(state[1:]).run()

      with self.assertRaisesOpError(
          "Could not parse state for IdentityReader 'test_reader'"):
        reader.restore_state(state[:-1]).run()

      with self.assertRaisesOpError(
          "Could not parse state for IdentityReader 'test_reader'"):
        reader.restore_state(state + b"ExtraJunk").run()

      with self.assertRaisesOpError(
          "Could not parse state for IdentityReader 'test_reader'"):
        reader.restore_state(b"PREFIX" + state).run()

      with self.assertRaisesOpError(
          "Could not parse state for IdentityReader 'test_reader'"):
        reader.restore_state(b"BOGUS" + state[5:]).run()

  def testReset(self):
    with self.test_session() as sess:
      reader = io_ops.IdentityReader("test_reader")
      work_completed = reader.num_work_units_completed()
      produced = reader.num_records_produced()
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      queued_length = queue.size()
      key, value = reader.read(queue)

      queue.enqueue_many([["X", "Y", "Z"]]).run()
      self._ExpectRead(sess, key, value, b"X")
      self.assertLess(0, queued_length.eval())
      self.assertAllEqual(1, produced.eval())

      self._ExpectRead(sess, key, value, b"Y")
      self.assertLess(0, work_completed.eval())
      self.assertAllEqual(2, produced.eval())

      reader.reset().run()
      self.assertAllEqual(0, work_completed.eval())
      self.assertAllEqual(0, produced.eval())
      self.assertAllEqual(1, queued_length.eval())
      self._ExpectRead(sess, key, value, b"Z")

      queue.enqueue_many([["K", "L"]]).run()
      self._ExpectRead(sess, key, value, b"K")


class WholeFileReaderTest(test.TestCase):
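  """Tests WholeFileReader, which returns (filename, full file contents) pairs."""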

  def setUp(self):
    super(WholeFileReaderTest, self).setUp()
    self._filenames = [
        os.path.join(self.get_temp_dir(), "whole_file.%d.txt" % i)
        for i in range(3)
    ]
    self._content = [b"One\na\nb\n", b"Two\nC\nD", b"Three x, y, z"]
    for fn, c in zip(self._filenames, self._content):
      with open(fn, "wb") as h:
        h.write(c)

  def tearDown(self):
    for fn in self._filenames:
      os.remove(fn)
    super(WholeFileReaderTest, self).tearDown()

  def _ExpectRead(self, sess, key, value, index):
    k, v = sess.run([key, value])
    self.assertAllEqual(compat.as_bytes(self._filenames[index]), k)
    self.assertAllEqual(self._content[index], v)

  def testOneEpoch(self):
    with self.test_session() as sess:
      reader = io_ops.WholeFileReader("test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      queue.enqueue_many([self._filenames]).run()
      queue.close().run()
      key, value = reader.read(queue)

      self._ExpectRead(sess, key, value, 0)
      self._ExpectRead(sess, key, value, 1)
      self._ExpectRead(sess, key, value, 2)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        sess.run([key, value])

  def testInfiniteEpochs(self):
    with self.test_session() as sess:
      reader = io_ops.WholeFileReader("test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      enqueue = queue.enqueue_many([self._filenames])
      key, value = reader.read(queue)

      enqueue.run()
      self._ExpectRead(sess, key, value, 0)
      self._ExpectRead(sess, key, value, 1)
      enqueue.run()
      self._ExpectRead(sess, key, value, 2)
      self._ExpectRead(sess, key, value, 0)
      self._ExpectRead(sess, key, value, 1)
      enqueue.run()
      self._ExpectRead(sess, key, value, 2)
      self._ExpectRead(sess, key, value, 0)


class TextLineReaderTest(test.TestCase):
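  """Tests TextLineReader, which returns text files one line at a time."""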

  def setUp(self):
    super(TextLineReaderTest, self).setUp()
    self._num_files = 2
    self._num_lines = 5

  def _LineText(self, f, l):
    return compat.as_bytes("%d: %d" % (f, l))

  def _CreateFiles(self, crlf=False):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
      filenames.append(fn)
      with open(fn, "wb") as f:
        for j in range(self._num_lines):
          f.write(self._LineText(i, j))
          # Always write a newline after the record, except after the last
          # record of a file; there we only write it for the first file, so
          # both trailing-newline cases are exercised.
          if j + 1 != self._num_lines or i == 0:
            f.write(b"\r\n" if crlf else b"\n")
    return filenames

  def _testOneEpoch(self, files):
    with self.test_session() as sess:
      reader = io_ops.TextLineReader(name="test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(self._num_lines):
          k, v = sess.run([key, value])
          self.assertAllEqual("%s:%d" % (files[i], j + 1), compat.as_text(k))
          self.assertAllEqual(self._LineText(i, j), v)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])

  def testOneEpochLF(self):
    self._testOneEpoch(self._CreateFiles(crlf=False))

  def testOneEpochCRLF(self):
    self._testOneEpoch(self._CreateFiles(crlf=True))

  def testSkipHeaderLines(self):
    files = self._CreateFiles()
    with self.test_session() as sess:
      reader = io_ops.TextLineReader(skip_header_lines=1, name="test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(self._num_lines - 1):
          k, v = sess.run([key, value])
          self.assertAllEqual("%s:%d" % (files[i], j + 2), compat.as_text(k))
          self.assertAllEqual(self._LineText(i, j + 1), v)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])


class FixedLengthRecordReaderTest(test.TestCase):
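  """Tests FixedLengthRecordReader on plain, gzip- and zlib-encoded files."""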

  def setUp(self):
    super(FixedLengthRecordReaderTest, self).setUp()
    self._num_files = 2
    self._header_bytes = 5
    self._record_bytes = 3
    self._footer_bytes = 2
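    # hop_bytes is intentionally smaller than record_bytes, so records in the
    # "overlapped" tests below overlap: record r starts at byte r * hop_bytes
    # of the payload (see _OverlappedRecord and _CreateOverlappedRecordFiles).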
    self._hop_bytes = 2

  def _Record(self, f, r):
    return compat.as_bytes(str(f * 2 + r) * self._record_bytes)

  def _OverlappedRecord(self, f, r):
    record_str = "".join([
        str(i)[0]
        for i in range(r * self._hop_bytes,
                       r * self._hop_bytes + self._record_bytes)
    ])
    return compat.as_bytes(record_str)

  # gap_bytes=hop_bytes-record_bytes
  def _CreateFiles(self, num_records, gap_bytes):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn, "wb") as f:
        f.write(b"H" * self._header_bytes)
        if num_records > 0:
          f.write(self._Record(i, 0))
        for j in range(1, num_records):
          if gap_bytes > 0:
            f.write(b"G" * gap_bytes)
          f.write(self._Record(i, j))
        f.write(b"F" * self._footer_bytes)
    return filenames

  def _CreateOverlappedRecordFiles(self, num_overlapped_records):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(),
                        "fixed_length_overlapped_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn, "wb") as f:
        f.write(b"H" * self._header_bytes)
        if num_overlapped_records > 0:
          all_records_str = "".join([
              str(i)[0]
              for i in range(self._record_bytes + self._hop_bytes *
                             (num_overlapped_records - 1))
          ])
          f.write(compat.as_bytes(all_records_str))
        f.write(b"F" * self._footer_bytes)
    return filenames

  # gap_bytes=hop_bytes-record_bytes
  def _CreateGzipFiles(self, num_records, gap_bytes):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
      filenames.append(fn)
      with gzip.GzipFile(fn, "wb") as f:
        f.write(b"H" * self._header_bytes)
        if num_records > 0:
          f.write(self._Record(i, 0))
        for j in range(1, num_records):
          if gap_bytes > 0:
            f.write(b"G" * gap_bytes)
          f.write(self._Record(i, j))
        f.write(b"F" * self._footer_bytes)
    return filenames

  # gap_bytes=hop_bytes-record_bytes
  def _CreateZlibFiles(self, num_records, gap_bytes):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn + ".tmp", "wb") as f:
        f.write(b"H" * self._header_bytes)
        if num_records > 0:
          f.write(self._Record(i, 0))
        for j in range(1, num_records):
          if gap_bytes > 0:
            f.write(b"G" * gap_bytes)
          f.write(self._Record(i, j))
        f.write(b"F" * self._footer_bytes)
      with open(fn + ".tmp", "rb") as f:
        cdata = zlib.compress(f.read())
        with open(fn, "wb") as zf:
          zf.write(cdata)
    return filenames

  def _CreateGzipOverlappedRecordFiles(self, num_overlapped_records):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(),
                        "fixed_length_overlapped_record.%d.txt" % i)
      filenames.append(fn)
      with gzip.GzipFile(fn, "wb") as f:
        f.write(b"H" * self._header_bytes)
        if num_overlapped_records > 0:
          all_records_str = "".join([
              str(i)[0]
              for i in range(self._record_bytes + self._hop_bytes *
                             (num_overlapped_records - 1))
          ])
          f.write(compat.as_bytes(all_records_str))
        f.write(b"F" * self._footer_bytes)
    return filenames

  def _CreateZlibOverlappedRecordFiles(self, num_overlapped_records):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(),
                        "fixed_length_overlapped_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn + ".tmp", "wb") as f:
        f.write(b"H" * self._header_bytes)
        if num_overlapped_records > 0:
          all_records_str = "".join([
              str(i)[0]
              for i in range(self._record_bytes + self._hop_bytes *
                             (num_overlapped_records - 1))
          ])
          f.write(compat.as_bytes(all_records_str))
        f.write(b"F" * self._footer_bytes)
      with open(fn + ".tmp", "rb") as f:
        cdata = zlib.compress(f.read())
        with open(fn, "wb") as zf:
          zf.write(cdata)
    return filenames

  # gap_bytes=hop_bytes-record_bytes
  def _TestOneEpoch(self, files, num_records, gap_bytes, encoding=None):
    hop_bytes = 0 if gap_bytes == 0 else self._record_bytes + gap_bytes
    with self.test_session() as sess:
      reader = io_ops.FixedLengthRecordReader(
          header_bytes=self._header_bytes,
          record_bytes=self._record_bytes,
          footer_bytes=self._footer_bytes,
          hop_bytes=hop_bytes,
          encoding=encoding,
          name="test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(num_records):
          k, v = sess.run([key, value])
          self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k))
          self.assertAllEqual(self._Record(i, j), v)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])

  def _TestOneEpochWithHopBytes(self,
                                files,
                                num_overlapped_records,
                                encoding=None):
    with self.test_session() as sess:
      reader = io_ops.FixedLengthRecordReader(
          header_bytes=self._header_bytes,
          record_bytes=self._record_bytes,
          footer_bytes=self._footer_bytes,
          hop_bytes=self._hop_bytes,
          encoding=encoding,
          name="test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(num_overlapped_records):
          k, v = sess.run([key, value])
          self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k))
          self.assertAllEqual(self._OverlappedRecord(i, j), v)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])

  def testOneEpoch(self):
    for num_records in [0, 7]:
      # gap_bytes=0: hop_bytes=0
      # gap_bytes=1: hop_bytes=record_bytes+1
      for gap_bytes in [0, 1]:
        files = self._CreateFiles(num_records, gap_bytes)
        self._TestOneEpoch(files, num_records, gap_bytes)

  def testGzipOneEpoch(self):
    for num_records in [0, 7]:
      # gap_bytes=0: hop_bytes=0
      # gap_bytes=1: hop_bytes=record_bytes+1
      for gap_bytes in [0, 1]:
        files = self._CreateGzipFiles(num_records, gap_bytes)
        self._TestOneEpoch(files, num_records, gap_bytes, encoding="GZIP")

  def testZlibOneEpoch(self):
    for num_records in [0, 7]:
      # gap_bytes=0: hop_bytes=0
      # gap_bytes=1: hop_bytes=record_bytes+1
      for gap_bytes in [0, 1]:
        files = self._CreateZlibFiles(num_records, gap_bytes)
        self._TestOneEpoch(files, num_records, gap_bytes, encoding="ZLIB")

  def testOneEpochWithHopBytes(self):
    for num_overlapped_records in [0, 2]:
      files = self._CreateOverlappedRecordFiles(num_overlapped_records)
      self._TestOneEpochWithHopBytes(files, num_overlapped_records)

  def testGzipOneEpochWithHopBytes(self):
    for num_overlapped_records in [0, 2]:
      files = self._CreateGzipOverlappedRecordFiles(num_overlapped_records)
      self._TestOneEpochWithHopBytes(
          files, num_overlapped_records, encoding="GZIP")

  def testZlibOneEpochWithHopBytes(self):
    for num_overlapped_records in [0, 2]:
      files = self._CreateZlibOverlappedRecordFiles(num_overlapped_records)
      self._TestOneEpochWithHopBytes(
          files, num_overlapped_records, encoding="ZLIB")


class TFRecordReaderTest(test.TestCase):
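  """Tests TFRecordReader, including ZLIB- and GZIP-compressed record files."""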

  def setUp(self):
    super(TFRecordReaderTest, self).setUp()
    self._num_files = 2
    self._num_records = 7

  def _Record(self, f, r):
    return compat.as_bytes("Record %d of file %d" % (r, f))

  def _CreateFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      writer = tf_record.TFRecordWriter(fn)
      for j in range(self._num_records):
        writer.write(self._Record(i, j))
      writer.close()
    return filenames

  def testOneEpoch(self):
    files = self._CreateFiles()
    with self.test_session() as sess:
      reader = io_ops.TFRecordReader(name="test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(self._num_records):
          k, v = sess.run([key, value])
          self.assertTrue(compat.as_text(k).startswith("%s:" % files[i]))
          self.assertAllEqual(self._Record(i, j), v)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])

  def testReadUpTo(self):
    files = self._CreateFiles()
    with self.test_session() as sess:
      reader = io_ops.TFRecordReader(name="test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      batch_size = 3
      key, value = reader.read_up_to(queue, batch_size)

      queue.enqueue_many([files]).run()
      queue.close().run()
      num_k = 0
      num_v = 0

      while True:
        try:
          k, v = sess.run([key, value])
          # Test reading *up to* batch_size records
          self.assertLessEqual(len(k), batch_size)
          self.assertLessEqual(len(v), batch_size)
          num_k += len(k)
          num_v += len(v)
        except errors_impl.OutOfRangeError:
          break

      # Test that we have read everything
      self.assertEqual(self._num_files * self._num_records, num_k)
      self.assertEqual(self._num_files * self._num_records, num_v)

  def testReadZlibFiles(self):
    files = self._CreateFiles()
    zlib_files = []
    for i, fn in enumerate(files):
      with open(fn, "rb") as f:
        cdata = zlib.compress(f.read())

        zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i)
        with open(zfn, "wb") as f:
          f.write(cdata)
        zlib_files.append(zfn)

    with self.test_session() as sess:
      options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
      reader = io_ops.TFRecordReader(name="test_reader", options=options)
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([zlib_files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(self._num_records):
          k, v = sess.run([key, value])
          self.assertTrue(compat.as_text(k).startswith("%s:" % zlib_files[i]))
          self.assertAllEqual(self._Record(i, j), v)

  def testReadGzipFiles(self):
    files = self._CreateFiles()
    gzip_files = []
    for i, fn in enumerate(files):
      with open(fn, "rb") as f:
        cdata = f.read()

        zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
        with gzip.GzipFile(zfn, "wb") as f:
          f.write(cdata)
        gzip_files.append(zfn)

    with self.test_session() as sess:
      options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
      reader = io_ops.TFRecordReader(name="test_reader", options=options)
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([gzip_files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(self._num_records):
          k, v = sess.run([key, value])
          self.assertTrue(compat.as_text(k).startswith("%s:" % gzip_files[i]))
          self.assertAllEqual(self._Record(i, j), v)


class TFRecordWriterZlibTest(test.TestCase):
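  """Tests that TFRecordWriter output interoperates with zlib and gzip."""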

  def setUp(self):
    super(TFRecordWriterZlibTest, self).setUp()
    self._num_files = 2
    self._num_records = 7

  def _Record(self, f, r):
    return compat.as_bytes("Record %d of file %d" % (r, f))

  def _CreateFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      options = tf_record.TFRecordOptions(
          compression_type=TFRecordCompressionType.ZLIB)
      writer = tf_record.TFRecordWriter(fn, options=options)
      for j in range(self._num_records):
        writer.write(self._Record(i, j))
      writer.close()
      del writer

    return filenames

  def _WriteRecordsToFile(self, records, name="tf_record"):
    fn = os.path.join(self.get_temp_dir(), name)
    writer = tf_record.TFRecordWriter(fn, options=None)
    for r in records:
      writer.write(r)
    writer.close()
    del writer
    return fn

  def _ZlibCompressFile(self, infile, name="tfrecord.z"):
    # zlib compress the file and write compressed contents to file.
    with open(infile, "rb") as f:
      cdata = zlib.compress(f.read())

    zfn = os.path.join(self.get_temp_dir(), name)
    with open(zfn, "wb") as f:
      f.write(cdata)
    return zfn

  def testOneEpoch(self):
    files = self._CreateFiles()
    with self.test_session() as sess:
      options = tf_record.TFRecordOptions(
          compression_type=TFRecordCompressionType.ZLIB)
      reader = io_ops.TFRecordReader(name="test_reader", options=options)
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(self._num_records):
          k, v = sess.run([key, value])
          self.assertTrue(compat.as_text(k).startswith("%s:" % files[i]))
          self.assertAllEqual(self._Record(i, j), v)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])

  def testZLibFlushRecord(self):
    fn = self._WriteRecordsToFile([b"small record"], "small_record")
    with open(fn, "rb") as h:
      buff = h.read()

    # creating more blocks and trailing blocks shouldn't break reads
    compressor = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS)

    output = b""
    for c in buff:
      if isinstance(c, int):
        c = six.int2byte(c)
      output += compressor.compress(c)
      output += compressor.flush(zlib.Z_FULL_FLUSH)

    output += compressor.flush(zlib.Z_FULL_FLUSH)
    output += compressor.flush(zlib.Z_FULL_FLUSH)
    output += compressor.flush(zlib.Z_FINISH)

    # overwrite the original file with the compressed data
    with open(fn, "wb") as h:
      h.write(output)

    with self.test_session() as sess:
      options = tf_record.TFRecordOptions(
          compression_type=TFRecordCompressionType.ZLIB)
      reader = io_ops.TFRecordReader(name="test_reader", options=options)
      queue = data_flow_ops.FIFOQueue(1, [dtypes.string], shapes=())
      key, value = reader.read(queue)
      queue.enqueue(fn).run()
      queue.close().run()
      k, v = sess.run([key, value])
      self.assertTrue(compat.as_text(k).startswith("%s:" % fn))
      self.assertAllEqual(b"small record", v)

  def testZlibReadWrite(self):
    """Verify that files produced are zlib compatible."""
    original = [b"foo", b"bar"]
    fn = self._WriteRecordsToFile(original, "zlib_read_write.tfrecord")
    zfn = self._ZlibCompressFile(fn, "zlib_read_write.tfrecord.z")

    # read the compressed contents and verify.
    actual = []
    for r in tf_record.tf_record_iterator(
        zfn,
        options=tf_record.TFRecordOptions(
            tf_record.TFRecordCompressionType.ZLIB)):
      actual.append(r)
    self.assertEqual(actual, original)

  def testZlibReadWriteLarge(self):
    """Verify that writing large contents also works."""

    # Make it large (about 5MB)
    original = [_TEXT * 10240]
    fn = self._WriteRecordsToFile(original, "zlib_read_write_large.tfrecord")
    zfn = self._ZlibCompressFile(fn, "zlib_read_write_large.tfrecord.z")

    # read the compressed contents and verify.
    actual = []
    for r in tf_record.tf_record_iterator(
        zfn,
        options=tf_record.TFRecordOptions(
            tf_record.TFRecordCompressionType.ZLIB)):
      actual.append(r)
    self.assertEqual(actual, original)

  def testGzipReadWrite(self):
    """Verify that files produced are gzip compatible."""
    original = [b"foo", b"bar"]
    fn = self._WriteRecordsToFile(original, "gzip_read_write.tfrecord")

    # gzip compress the file and write compressed contents to file.
    with open(fn, "rb") as f:
      cdata = f.read()
    gzfn = os.path.join(self.get_temp_dir(), "tf_record.gz")
    with gzip.GzipFile(gzfn, "wb") as f:
      f.write(cdata)

    actual = []
    for r in tf_record.tf_record_iterator(
        gzfn, options=tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)):
      actual.append(r)
    self.assertEqual(actual, original)


class TFRecordIteratorTest(test.TestCase):
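  """Tests tf_record_iterator on compressed and corrupted record files."""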

  def setUp(self):
    super(TFRecordIteratorTest, self).setUp()
    self._num_records = 7

  def _Record(self, r):
    return compat.as_bytes("Record %d" % r)

  def _WriteCompressedRecordsToFile(
      self,
      records,
      name="tfrecord.z",
      compression_type=tf_record.TFRecordCompressionType.ZLIB):
    fn = os.path.join(self.get_temp_dir(), name)
    options = tf_record.TFRecordOptions(compression_type=compression_type)
    writer = tf_record.TFRecordWriter(fn, options=options)
    for r in records:
      writer.write(r)
    writer.close()
    del writer
    return fn

  def _ZlibDecompressFile(self, infile, name="tfrecord", wbits=zlib.MAX_WBITS):
    with open(infile, "rb") as f:
      cdata = zlib.decompress(f.read(), wbits)
    zfn = os.path.join(self.get_temp_dir(), name)
    with open(zfn, "wb") as f:
      f.write(cdata)
    return zfn

  def testIterator(self):
    fn = self._WriteCompressedRecordsToFile(
        [self._Record(i) for i in range(self._num_records)],
        "compressed_records")
    options = tf_record.TFRecordOptions(
        compression_type=TFRecordCompressionType.ZLIB)
    reader = tf_record.tf_record_iterator(fn, options)
    for i in range(self._num_records):
      record = next(reader)
      self.assertAllEqual(self._Record(i), record)
    with self.assertRaises(StopIteration):
      record = next(reader)

  def testWriteZlibRead(self):
    """Verify compression with TFRecordWriter is zlib library compatible."""
    original = [b"foo", b"bar"]
    fn = self._WriteCompressedRecordsToFile(original,
                                            "write_zlib_read.tfrecord.z")
    zfn = self._ZlibDecompressFile(fn, "write_zlib_read.tfrecord")
    actual = []
    for r in tf_record.tf_record_iterator(zfn):
      actual.append(r)
    self.assertEqual(actual, original)

  def testWriteZlibReadLarge(self):
    """Verify compression for large records is zlib library compatible."""
    # Make it large (about 5MB)
    original = [_TEXT * 10240]
    fn = self._WriteCompressedRecordsToFile(original,
                                            "write_zlib_read_large.tfrecord.z")
    zfn = self._ZlibDecompressFile(fn, "write_zlib_read_large.tf_record")
    actual = []
    for r in tf_record.tf_record_iterator(zfn):
      actual.append(r)
    self.assertEqual(actual, original)

  def testWriteGzipRead(self):
    original = [b"foo", b"bar"]
    fn = self._WriteCompressedRecordsToFile(
        original,
        "write_gzip_read.tfrecord.gz",
        compression_type=TFRecordCompressionType.GZIP)

    with gzip.GzipFile(fn, "rb") as f:
      cdata = f.read()
    zfn = os.path.join(self.get_temp_dir(), "tf_record")
    with open(zfn, "wb") as f:
      f.write(cdata)

    actual = []
    for r in tf_record.tf_record_iterator(zfn):
      actual.append(r)
    self.assertEqual(actual, original)

  def testBadFile(self):
    """Verify that tf_record_iterator throws an exception on bad TFRecords."""
    fn = os.path.join(self.get_temp_dir(), "bad_file")
    with tf_record.TFRecordWriter(fn) as writer:
      writer.write(b"123")
    fn_truncated = os.path.join(self.get_temp_dir(), "bad_file_truncated")
    with open(fn, "rb") as f:
      with open(fn_truncated, "wb") as f2:
        # DataLossError requires that we've written the header, so this must
        # be at least 12 bytes.
        f2.write(f.read(14))
    with self.assertRaises(errors_impl.DataLossError):
      for _ in tf_record.tf_record_iterator(fn_truncated):
        pass


class AsyncReaderTest(test.TestCase):
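  """Tests that pending reads do not deadlock the session's executor threads."""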

  def testNoDeadlockFromQueue(self):
    """Tests that reading does not block main execution threads."""
    config = config_pb2.ConfigProto(
        inter_op_parallelism_threads=1, intra_op_parallelism_threads=1)
    with self.test_session(config=config) as sess:
      thread_data_t = collections.namedtuple("thread_data_t",
                                             ["thread", "queue", "output"])
      thread_data = []

      # Create different readers, each with its own queue.
      for i in range(3):
        queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
        reader = io_ops.TextLineReader()
        _, line = reader.read(queue)
        output = []
        t = threading.Thread(
            target=AsyncReaderTest._RunSessionAndSave,
            args=(sess, [line], output))
        thread_data.append(thread_data_t(t, queue, output))

      # Start all readers. They are all blocked waiting for queue entries.
      sess.run(variables.global_variables_initializer())
      for d in thread_data:
        d.thread.start()

      # Unblock the readers.
      for i, d in enumerate(reversed(thread_data)):
        fname = os.path.join(self.get_temp_dir(), "deadlock.%s.txt" % i)
        with open(fname, "wb") as f:
          f.write(("file-%s" % i).encode())
        d.queue.enqueue_many([[fname]]).run()
        d.thread.join()
        self.assertEqual([[("file-%s" % i).encode()]], d.output)

  @staticmethod
  def _RunSessionAndSave(sess, args, output):
    output.append(sess.run(args))


class LMDBReaderTest(test.TestCase):
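  """Tests LMDBReader against a small LMDB database copied to a temp dir."""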

  def setUp(self):
    super(LMDBReaderTest, self).setUp()
    # Copy database out because we need the path to be writable to use locks.
    path = os.path.join(prefix_path, "lmdb", "testdata", "data.mdb")
    self.db_path = os.path.join(self.get_temp_dir(), "data.mdb")
    shutil.copy(path, self.db_path)

  def testReadFromFile(self):
    with self.test_session() as sess:
      reader = io_ops.LMDBReader(name="test_read_from_file")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue([self.db_path]).run()
      queue.close().run()
      for i in range(10):
        k, v = sess.run([key, value])
        self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(i)))
        self.assertAllEqual(
            compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + i))))

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])

  def testReadFromSameFile(self):
    with self.test_session() as sess:
      reader1 = io_ops.LMDBReader(name="test_read_from_same_file1")
      reader2 = io_ops.LMDBReader(name="test_read_from_same_file2")
      filename_queue = input_lib.string_input_producer(
          [self.db_path], num_epochs=None)
      key1, value1 = reader1.read(filename_queue)
      key2, value2 = reader2.read(filename_queue)

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
      for _ in range(3):
        for _ in range(10):
          k1, v1, k2, v2 = sess.run([key1, value1, key2, value2])
          self.assertAllEqual(compat.as_bytes(k1), compat.as_bytes(k2))
          self.assertAllEqual(compat.as_bytes(v1), compat.as_bytes(v2))
      coord.request_stop()
      coord.join(threads)

  def testReadFromFolder(self):
    with self.test_session() as sess:
      reader = io_ops.LMDBReader(name="test_read_from_folder")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue([self.db_path]).run()
      queue.close().run()
      for i in range(10):
        k, v = sess.run([key, value])
        self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(i)))
        self.assertAllEqual(
            compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + i))))

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])

  def testReadFromFileRepeatedly(self):
    with self.test_session() as sess:
      reader = io_ops.LMDBReader(name="test_read_from_file_repeated")
      filename_queue = input_lib.string_input_producer(
          [self.db_path], num_epochs=None)
      key, value = reader.read(filename_queue)

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
      # Iterate over the lmdb 3 times.
      for _ in range(3):
        # Go over all 10 records each time.
        for j in range(10):
          k, v = sess.run([key, value])
          self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(j)))
          self.assertAllEqual(
              compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + j))))
      coord.request_stop()
      coord.join(threads)


if __name__ == "__main__":
  test.main()