# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import zlib

from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat


class TextLineDatasetTest(test.TestCase):
  """Tests for `readers.TextLineDataset` over plain and compressed files."""

  def _lineText(self, f, l):
    """Returns the expected bytes content of line `l` of file `f`."""
    return compat.as_bytes("%d: %d" % (f, l))

  def _createFiles(self,
                   num_files,
                   num_lines,
                   crlf=False,
                   compression_type=None):
    """Writes `num_files` text files of `num_lines` lines each.

    Args:
      num_files: Number of files to create in the test temp dir.
      num_lines: Number of lines written to each file.
      crlf: If True, terminate lines with b"\r\n" instead of b"\n".
      compression_type: None (plain), "GZIP", or "ZLIB".

    Returns:
      The list of created filenames.

    Raises:
      ValueError: If `compression_type` is not one of the supported values.
    """
    filenames = []
    for i in range(num_files):
      fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
      filenames.append(fn)
      contents = []
      for j in range(num_lines):
        contents.append(self._lineText(i, j))
        # Always include a newline after the record unless it is
        # at the end of the file, in which case we include it only
        # in the first file (so both trailing-newline layouts are covered).
        if j + 1 != num_lines or i == 0:
          contents.append(b"\r\n" if crlf else b"\n")
      contents = b"".join(contents)

      if not compression_type:
        with open(fn, "wb") as f:
          f.write(contents)
      elif compression_type == "GZIP":
        with gzip.GzipFile(fn, "wb") as f:
          f.write(contents)
      elif compression_type == "ZLIB":
        contents = zlib.compress(contents)
        with open(fn, "wb") as f:
          f.write(contents)
      else:
        raise ValueError("Unsupported compression_type", compression_type)

    return filenames

  def _testTextLineDataset(self, compression_type=None):
    """Exercises single-file, multi-file, repeated and batched reading."""
    test_filenames = self._createFiles(
        2, 5, crlf=True, compression_type=compression_type)
    filenames = array_ops.placeholder(dtypes.string, shape=[None])
    num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
    batch_size = array_ops.placeholder(dtypes.int64, shape=[])

    repeat_dataset = readers.TextLineDataset(
        filenames, compression_type=compression_type).repeat(num_epochs)
    batch_dataset = repeat_dataset.batch(batch_size)

    # A single reinitializable iterator is shared between the unbatched and
    # batched datasets; each initializer rebinds it to one of them.
    iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
    init_op = iterator.make_initializer(repeat_dataset)
    init_batch_op = iterator.make_initializer(batch_dataset)
    get_next = iterator.get_next()

    with self.test_session() as sess:
      # Basic test: read from file 0.
      sess.run(
          init_op, feed_dict={filenames: [test_filenames[0]],
                              num_epochs: 1})
      for i in range(5):
        self.assertEqual(self._lineText(0, i), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Basic test: read from file 1.
      sess.run(
          init_op, feed_dict={filenames: [test_filenames[1]],
                              num_epochs: 1})
      for i in range(5):
        self.assertEqual(self._lineText(1, i), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Basic test: read from both files.
      sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 1})
      for j in range(2):
        for i in range(5):
          self.assertEqual(self._lineText(j, i), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Test repeated iteration through both files.
      sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 10})
      for _ in range(10):
        for j in range(2):
          for i in range(5):
            self.assertEqual(self._lineText(j, i), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Test batched and repeated iteration through both files.
      sess.run(
          init_batch_op,
          feed_dict={filenames: test_filenames,
                     num_epochs: 10,
                     batch_size: 5})
      for _ in range(10):
        self.assertAllEqual([self._lineText(0, i) for i in range(5)],
                            sess.run(get_next))
        self.assertAllEqual([self._lineText(1, i) for i in range(5)],
                            sess.run(get_next))

  def testTextLineDatasetNoCompression(self):
    self._testTextLineDataset()

  def testTextLineDatasetGzipCompression(self):
    self._testTextLineDataset(compression_type="GZIP")

  def testTextLineDatasetZlibCompression(self):
    self._testTextLineDataset(compression_type="ZLIB")

  def testTextLineDatasetBuffering(self):
    # buffer_size is smaller than a whole file, so reads span buffer refills.
    test_filenames = self._createFiles(2, 5, crlf=True)

    repeat_dataset = readers.TextLineDataset(test_filenames, buffer_size=10)
    iterator = repeat_dataset.make_one_shot_iterator()

    with self.test_session() as sess:
      for j in range(2):
        for i in range(5):
          self.assertEqual(self._lineText(j, i), sess.run(iterator.get_next()))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(iterator.get_next())


class FixedLengthRecordReaderTest(test.TestCase):
  """Tests for `readers.FixedLengthRecordDataset`, incl. save/restore."""

  def setUp(self):
    super(FixedLengthRecordReaderTest, self).setUp()
    # File layout: header (5 bytes) + 7 records of 3 bytes + footer (2 bytes).
    self._num_files = 2
    self._num_records = 7
    self._header_bytes = 5
    self._record_bytes = 3
    self._footer_bytes = 2

  def _record(self, f, r):
    """Returns the expected bytes of record `r` of file `f`."""
    return compat.as_bytes(str(f * 2 + r) * self._record_bytes)

  def _createFiles(self):
    """Writes the fixed-length-record test files; returns their names."""
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn, "wb") as f:
        f.write(b"H" * self._header_bytes)
        for j in range(self._num_records):
          f.write(self._record(i, j))
        f.write(b"F" * self._footer_bytes)
    return filenames

  def testFixedLengthRecordDataset(self):
    test_filenames = self._createFiles()
    filenames = array_ops.placeholder(dtypes.string, shape=[None])
    num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
    batch_size = array_ops.placeholder(dtypes.int64, shape=[])

    repeat_dataset = (readers.FixedLengthRecordDataset(
        filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
                      .repeat(num_epochs))
    batch_dataset = repeat_dataset.batch(batch_size)

    iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
    init_op = iterator.make_initializer(repeat_dataset)
    init_batch_op = iterator.make_initializer(batch_dataset)
    get_next = iterator.get_next()

    with self.test_session() as sess:
      # Basic test: read from file 0.
      sess.run(
          init_op, feed_dict={filenames: [test_filenames[0]],
                              num_epochs: 1})
      for i in range(self._num_records):
        self.assertEqual(self._record(0, i), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Basic test: read from file 1.
      sess.run(
          init_op, feed_dict={filenames: [test_filenames[1]],
                              num_epochs: 1})
      for i in range(self._num_records):
        self.assertEqual(self._record(1, i), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Basic test: read from both files.
      sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 1})
      for j in range(self._num_files):
        for i in range(self._num_records):
          self.assertEqual(self._record(j, i), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Test repeated iteration through both files.
      sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 10})
      for _ in range(10):
        for j in range(self._num_files):
          for i in range(self._num_records):
            self.assertEqual(self._record(j, i), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Test batched and repeated iteration through both files.
      sess.run(
          init_batch_op,
          feed_dict={
              filenames: test_filenames,
              num_epochs: 10,
              batch_size: self._num_records
          })
      for _ in range(10):
        for j in range(self._num_files):
          self.assertAllEqual(
              [self._record(j, i) for i in range(self._num_records)],
              sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testFixedLengthRecordDatasetBuffering(self):
    # buffer_size (10) is smaller than one file's payload, forcing refills.
    test_filenames = self._createFiles()
    dataset = readers.FixedLengthRecordDataset(
        test_filenames,
        self._record_bytes,
        self._header_bytes,
        self._footer_bytes,
        buffer_size=10)
    iterator = dataset.make_one_shot_iterator()

    with self.test_session() as sess:
      for j in range(self._num_files):
        for i in range(self._num_records):
          self.assertEqual(self._record(j, i), sess.run(iterator.get_next()))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(iterator.get_next())

  def testFixedLengthRecordDatasetWrongSize(self):
    # A record length that does not evenly divide the body (21 bytes) must
    # surface an InvalidArgumentError naming the offending file.
    test_filenames = self._createFiles()
    dataset = readers.FixedLengthRecordDataset(
        test_filenames,
        self._record_bytes + 1,  # Incorrect record length.
        self._header_bytes,
        self._footer_bytes,
        buffer_size=10)
    iterator = dataset.make_one_shot_iterator()

    with self.test_session() as sess:
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          r"Excluding the header \(5 bytes\) and footer \(2 bytes\), input "
          r"file \".*fixed_length_record.0.txt\" has body length 21 bytes, "
          r"which is not an exact multiple of the record length \(4 bytes\)."):
        sess.run(iterator.get_next())

  def _iterator_checkpoint_path(self):
    """Path of the file used to persist serialized iterator state."""
    return os.path.join(self.get_temp_dir(), "iterator")

  def _save_op(self, iterator_resource):
    """Returns an op that serializes the iterator state to the checkpoint."""
    iterator_state_variant = gen_dataset_ops.serialize_iterator(
        iterator_resource)
    save_op = io_ops.write_file(
        self._iterator_checkpoint_path(),
        parsing_ops.serialize_tensor(iterator_state_variant))
    return save_op

  def _restore_op(self, iterator_resource):
    """Returns an op that restores iterator state from the checkpoint."""
    iterator_state_variant = parsing_ops.parse_tensor(
        io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant)
    restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
                                                      iterator_state_variant)
    return restore_op

  def _build_iterator_graph(self, num_epochs):
    """Builds a dataset iterator plus its init/get_next/save/restore ops."""
    filenames = self._createFiles()
    dataset = (readers.FixedLengthRecordDataset(
        filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
               .repeat(num_epochs))
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next_op = iterator.get_next()
    # Accessing the private resource handle is deliberate: the serialize /
    # deserialize ops operate on the raw iterator resource.
    save_op = self._save_op(iterator._iterator_resource)
    restore_op = self._restore_op(iterator._iterator_resource)
    return init_op, get_next_op, save_op, restore_op

  def _restore_iterator(self):
    """Builds a structure-only iterator and its restore op (no dataset)."""
    output_types = dtypes.string
    output_shapes = tensor_shape.scalar()
    iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
    get_next = iterator.get_next()
    restore_op = self._restore_op(iterator._iterator_resource)
    return restore_op, get_next

  def testSaveRestore(self):
    num_epochs = 10
    epoch_break = 5
    file_break = self._num_files // 2
    record_break = self._num_records // 2

    with ops.Graph().as_default() as g:
      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
          num_epochs=num_epochs)
      with self.test_session(graph=g) as sess:
        sess.run(init_op)
        # Note: There is no checkpoint saved currently so a NotFoundError is
        # raised.
        with self.assertRaises(errors.NotFoundError):
          sess.run(restore_op)
        # The for/else chains below propagate the inner `break` all the way
        # out once the save point is reached; the final `else` only runs if
        # the save point was never hit (i.e. iteration completed).
        for epoch in range(num_epochs):
          for f in range(self._num_files):
            for r in range(self._num_records):
              if (epoch == epoch_break and f == file_break and
                  r == record_break):
                sess.run(save_op)
                break
              self.assertEqual(self._record(f, r), sess.run(get_next_op))
            else:
              continue
            break
          else:
            continue
          break
        else:
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)

    with ops.Graph().as_default() as g:
      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
          num_epochs=num_epochs)
      with self.test_session(graph=g) as sess:
        sess.run(restore_op)
        # Skip everything consumed before the save point, then verify the
        # remainder of the stream.
        for epoch in range(num_epochs):
          for f in range(self._num_files):
            for r in range(self._num_records):
              if (epoch < epoch_break or
                  (epoch == epoch_break and f < file_break) or
                  (epoch == epoch_break and f == file_break and
                   r < record_break)):
                continue
              self.assertEqual(self._record(f, r), sess.run(get_next_op))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next_op)

  def testInitThenRestore(self):
    # Note: Calling init_op before restore_op is redundant. This test just makes
    # sure we do not fail if restore is called on an already initialized
    # iterator resource.
    num_epochs = 10
    epoch_break = 5
    file_break = self._num_files // 2
    record_break = self._num_records // 2

    with ops.Graph().as_default() as g:
      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
          num_epochs=num_epochs)
      with self.test_session(graph=g) as sess:
        sess.run(init_op)
        # Note: There is no checkpoint saved currently so a NotFoundError is
        # raised.
        with self.assertRaises(errors.NotFoundError):
          sess.run(restore_op)
        # for/else chains propagate the inner `break` out once we save.
        for epoch in range(num_epochs):
          for f in range(self._num_files):
            for r in range(self._num_records):
              if (epoch == epoch_break and f == file_break and
                  r == record_break):
                sess.run(save_op)
                break
              self.assertEqual(self._record(f, r), sess.run(get_next_op))
            else:
              continue
            break
          else:
            continue
          break
        else:
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)

    with ops.Graph().as_default() as g:
      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
          num_epochs=num_epochs)
      with self.test_session(graph=g) as sess:
        sess.run(init_op)
        sess.run(restore_op)
        for epoch in range(num_epochs):
          for f in range(self._num_files):
            for r in range(self._num_records):
              if (epoch < epoch_break or
                  (epoch == epoch_break and f < file_break) or
                  (epoch == epoch_break and f == file_break and
                   r < record_break)):
                continue
              self.assertEqual(self._record(f, r), sess.run(get_next_op))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next_op)

  def testRestoreInModifiedGraph(self):
    # Restores into a graph whose dataset repeats more epochs (num_epochs_1);
    # the restored state should still resume at the original save point.
    num_epochs = 10
    num_epochs_1 = 20
    epoch_break = 5
    file_break = self._num_files // 2
    record_break = self._num_records // 2

    with ops.Graph().as_default() as g:
      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
          num_epochs=num_epochs)
      with self.test_session(graph=g) as sess:
        sess.run(init_op)
        # Note: There is no checkpoint saved currently so a NotFoundError is
        # raised.
        with self.assertRaises(errors.NotFoundError):
          sess.run(restore_op)
        # for/else chains propagate the inner `break` out once we save.
        for epoch in range(num_epochs):
          for f in range(self._num_files):
            for r in range(self._num_records):
              if (epoch == epoch_break and f == file_break and
                  r == record_break):
                sess.run(save_op)
                break
              self.assertEqual(self._record(f, r), sess.run(get_next_op))
            else:
              continue
            break
          else:
            continue
          break
        else:
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)

    with ops.Graph().as_default() as g:
      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
          num_epochs=num_epochs_1)
      with self.test_session(graph=g) as sess:
        sess.run(restore_op)
        for epoch in range(num_epochs):
          for f in range(self._num_files):
            for r in range(self._num_records):
              if (epoch < epoch_break or
                  (epoch == epoch_break and f < file_break) or
                  (epoch == epoch_break and f == file_break and
                   r < record_break)):
                continue
              self.assertEqual(self._record(f, r), sess.run(get_next_op))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next_op)

  def testRestoreWithoutBuildingDatasetGraph(self):
    # Restores into an iterator built only from the output structure (no
    # dataset graph) via _restore_iterator.
    num_epochs = 10
    epoch_break = 5
    file_break = self._num_files // 2
    record_break = self._num_records // 2

    with ops.Graph().as_default() as g:
      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
          num_epochs=num_epochs)
      with self.test_session(graph=g) as sess:
        sess.run(init_op)
        # Note: There is no checkpoint saved currently so a NotFoundError is
        # raised.
        with self.assertRaises(errors.NotFoundError):
          sess.run(restore_op)
        # for/else chains propagate the inner `break` out once we save.
        for epoch in range(num_epochs):
          for f in range(self._num_files):
            for r in range(self._num_records):
              if (epoch == epoch_break and f == file_break and
                  r == record_break):
                sess.run(save_op)
                break
              self.assertEqual(self._record(f, r), sess.run(get_next_op))
            else:
              continue
            break
          else:
            continue
          break
        else:
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)

    with ops.Graph().as_default() as g:
      restore_op, get_next_op = self._restore_iterator()
      with self.test_session(graph=g) as sess:
        sess.run(restore_op)
        for epoch in range(num_epochs):
          for f in range(self._num_files):
            for r in range(self._num_records):
              if (epoch < epoch_break or
                  (epoch == epoch_break and f < file_break) or
                  (epoch == epoch_break and f == file_break and
                   r < record_break)):
                continue
              self.assertEqual(self._record(f, r), sess.run(get_next_op))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next_op)

  def testRestoreUnusedIterator(self):
    num_epochs = 10
    with ops.Graph().as_default() as g:
      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
          num_epochs=num_epochs)
      with self.test_session(graph=g) as sess:
        sess.run(init_op)
        # Note: There is no checkpoint saved currently so a NotFoundError is
        # raised.
        with self.assertRaises(errors.NotFoundError):
          sess.run(restore_op)
        # Save unused iterator.
        sess.run(save_op)
    with ops.Graph().as_default() as g:
      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
          num_epochs=num_epochs)
      with self.test_session(graph=g) as sess:
        sess.run(restore_op)
        # The restored (unused) iterator should yield the full stream.
        for _ in range(num_epochs * self._num_files * self._num_records):
          sess.run(get_next_op)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next_op)

  def testRestoreExhaustedIterator(self):
    num_epochs = 10

    with ops.Graph().as_default() as g:
      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
          num_epochs=num_epochs)
      with self.test_session(graph=g) as sess:
        sess.run(init_op)
        # Note: There is no checkpoint saved currently so a NotFoundError is
        # raised.
        with self.assertRaises(errors.NotFoundError):
          sess.run(restore_op)
        for _ in range(num_epochs):
          for f in range(self._num_files):
            for r in range(self._num_records):
              self.assertEqual(self._record(f, r), sess.run(get_next_op))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next_op)
        sess.run(save_op)

    with ops.Graph().as_default() as g:
      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
          num_epochs=num_epochs)
      with self.test_session(graph=g) as sess:
        sess.run(restore_op)
        # An exhausted iterator restores as exhausted.
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next_op)


class TFRecordDatasetTest(test.TestCase):
  """Tests for `readers.TFRecordDataset` over plain and compressed files."""

  def setUp(self):
    super(TFRecordDatasetTest, self).setUp()
    self._num_files = 2
    self._num_records = 7

    self.test_filenames = self._createFiles()

    self.filenames = array_ops.placeholder(dtypes.string, shape=[None])
    # num_epochs and compression_type default so tests may omit them.
    self.num_epochs = array_ops.placeholder_with_default(
        constant_op.constant(1, dtypes.int64), shape=[])
    self.compression_type = array_ops.placeholder_with_default("", shape=[])
    self.batch_size = array_ops.placeholder(dtypes.int64, shape=[])

    repeat_dataset = readers.TFRecordDataset(self.filenames,
                                             self.compression_type).repeat(
                                                 self.num_epochs)
    batch_dataset = repeat_dataset.batch(self.batch_size)

    iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
    self.init_op = iterator.make_initializer(repeat_dataset)
    self.init_batch_op = iterator.make_initializer(batch_dataset)
    self.get_next = iterator.get_next()

  def _record(self, f, r):
    """Returns the expected bytes of record `r` of file `f`."""
    return compat.as_bytes("Record %d of file %d" % (r, f))

  def _createFiles(self):
    """Writes the TFRecord test files; returns their names."""
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      writer = python_io.TFRecordWriter(fn)
      for j in range(self._num_records):
        writer.write(self._record(i, j))
      writer.close()
    return filenames

  def testReadOneEpoch(self):
    with self.test_session() as sess:
      # Basic test: read from file 0.
      sess.run(
          self.init_op,
          feed_dict={
              self.filenames: [self.test_filenames[0]],
              self.num_epochs: 1
          })
      for i in range(self._num_records):
        self.assertAllEqual(self._record(0, i), sess.run(self.get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.get_next)

      # Basic test: read from file 1.
      sess.run(
          self.init_op,
          feed_dict={
              self.filenames: [self.test_filenames[1]],
              self.num_epochs: 1
          })
      for i in range(self._num_records):
        self.assertAllEqual(self._record(1, i), sess.run(self.get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.get_next)

      # Basic test: read from both files.
      sess.run(
          self.init_op,
          feed_dict={self.filenames: self.test_filenames,
                     self.num_epochs: 1})
      for j in range(self._num_files):
        for i in range(self._num_records):
          self.assertAllEqual(self._record(j, i), sess.run(self.get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.get_next)

  def testReadTenEpochs(self):
    with self.test_session() as sess:
      sess.run(
          self.init_op,
          feed_dict={self.filenames: self.test_filenames,
                     self.num_epochs: 10})
      for _ in range(10):
        for j in range(self._num_files):
          for i in range(self._num_records):
            self.assertAllEqual(self._record(j, i), sess.run(self.get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.get_next)

  def testReadTenEpochsOfBatches(self):
    with self.test_session() as sess:
      sess.run(
          self.init_batch_op,
          feed_dict={
              self.filenames: self.test_filenames,
              self.num_epochs: 10,
              self.batch_size: self._num_records
          })
      for _ in range(10):
        for j in range(self._num_files):
          values = sess.run(self.get_next)
          self.assertAllEqual(
              [self._record(j, i) for i in range(self._num_records)], values)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.get_next)

  def testReadZlibFiles(self):
    # Compress the already-written TFRecord files with raw zlib.
    zlib_files = []
    for i, fn in enumerate(self.test_filenames):
      with open(fn, "rb") as f:
        cdata = zlib.compress(f.read())

        zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i)
        with open(zfn, "wb") as f:
          f.write(cdata)
        zlib_files.append(zfn)

    with self.test_session() as sess:
      sess.run(
          self.init_op,
          feed_dict={self.filenames: zlib_files,
                     self.compression_type: "ZLIB"})
      for j in range(self._num_files):
        for i in range(self._num_records):
          self.assertAllEqual(self._record(j, i), sess.run(self.get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.get_next)

  def testReadGzipFiles(self):
    # Re-wrap the already-written TFRecord files in gzip containers.
    gzip_files = []
    for i, fn in enumerate(self.test_filenames):
      with open(fn, "rb") as f:
        gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
        with gzip.GzipFile(gzfn, "wb") as gzf:
          gzf.write(f.read())
        gzip_files.append(gzfn)

    with self.test_session() as sess:
      sess.run(
          self.init_op,
          feed_dict={self.filenames: gzip_files,
                     self.compression_type: "GZIP"})
      for j in range(self._num_files):
        for i in range(self._num_records):
          self.assertAllEqual(self._record(j, i), sess.run(self.get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.get_next)

  def testReadWithBuffer(self):
    one_mebibyte = 2**20
    d = readers.TFRecordDataset(self.test_filenames, buffer_size=one_mebibyte)
    iterator = d.make_one_shot_iterator()
    with self.test_session() as sess:
      for j in range(self._num_files):
        for i in range(self._num_records):
          self.assertAllEqual(self._record(j, i), sess.run(iterator.get_next()))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(iterator.get_next())


if __name__ == "__main__":
  test.main()