1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/core/framework/partial_tensor_shape.h" 16 #include "tensorflow/core/framework/tensor.h" 17 #include "tensorflow/core/kernels/data/dataset.h" 18 #include "tensorflow/core/lib/io/buffered_inputstream.h" 19 #include "tensorflow/core/lib/io/inputbuffer.h" 20 #include "tensorflow/core/lib/io/random_inputstream.h" 21 #include "tensorflow/core/lib/io/record_reader.h" 22 #include "tensorflow/core/lib/io/zlib_compression_options.h" 23 #include "tensorflow/core/lib/io/zlib_inputstream.h" 24 25 namespace tensorflow { 26 27 namespace { 28 29 // See documentation in ../ops/dataset_ops.cc for a high-level 30 // description of the following ops. 31 32 class TextLineDatasetOp : public DatasetOpKernel { 33 public: 34 using DatasetOpKernel::DatasetOpKernel; 35 36 void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { 37 const Tensor* filenames_tensor; 38 OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); 39 OP_REQUIRES( 40 ctx, filenames_tensor->dims() <= 1, 41 errors::InvalidArgument("`filenames` must be a scalar or a vector.")); 42 43 string compression_type; 44 OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "compression_type", 45 &compression_type)); 46 47 int64 buffer_size = -1; 48 OP_REQUIRES_OK( 49 ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size)); 50 OP_REQUIRES( 51 ctx, buffer_size >= 0, 52 errors::InvalidArgument("`buffer_size` must be >= 0 (0 == default)")); 53 54 io::ZlibCompressionOptions zlib_compression_options = 55 io::ZlibCompressionOptions::DEFAULT(); 56 if (compression_type == "ZLIB") { 57 zlib_compression_options = io::ZlibCompressionOptions::DEFAULT(); 58 } else if (compression_type == "GZIP") { 59 zlib_compression_options = io::ZlibCompressionOptions::GZIP(); 60 } else { 61 OP_REQUIRES(ctx, compression_type.empty(), 62 errors::InvalidArgument("Unsupported compression_type.")); 63 } 64 65 if (buffer_size != 0) { 66 // Set the override size. 67 zlib_compression_options.input_buffer_size = buffer_size; 68 } 69 70 std::vector<string> filenames; 71 filenames.reserve(filenames_tensor->NumElements()); 72 for (int i = 0; i < filenames_tensor->NumElements(); ++i) { 73 filenames.push_back(filenames_tensor->flat<string>()(i)); 74 } 75 76 *output = new Dataset(ctx, std::move(filenames), compression_type, 77 zlib_compression_options); 78 } 79 80 private: 81 class Dataset : public GraphDatasetBase { 82 public: 83 Dataset(OpKernelContext* ctx, std::vector<string> filenames, 84 const string& compression_type, 85 const io::ZlibCompressionOptions& options) 86 : GraphDatasetBase(ctx), 87 filenames_(std::move(filenames)), 88 compression_type_(compression_type), 89 use_compression_(!compression_type.empty()), 90 options_(options) {} 91 92 std::unique_ptr<IteratorBase> MakeIterator( 93 const string& prefix) const override { 94 return std::unique_ptr<IteratorBase>( 95 new Iterator({this, strings::StrCat(prefix, "::TextLine")})); 96 } 97 98 const DataTypeVector& output_dtypes() const override { 99 static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); 100 return *dtypes; 101 } 102 103 const std::vector<PartialTensorShape>& output_shapes() const override { 104 static std::vector<PartialTensorShape>* shapes = 105 new std::vector<PartialTensorShape>({{}}); 106 return *shapes; 107 } 108 109 string DebugString() override { return "TextLineDatasetOp::Dataset"; } 110 111 protected: 112 Status AsGraphDefInternal(DatasetGraphDefBuilder* b, 113 Node** output) const override { 114 Node* filenames = nullptr; 115 Node* compression_type = nullptr; 116 Node* buffer_size = nullptr; 117 TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); 118 TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type)); 119 TF_RETURN_IF_ERROR( 120 b->AddScalar(options_.input_buffer_size, &buffer_size)); 121 TF_RETURN_IF_ERROR(b->AddDataset( 122 this, {filenames, compression_type, buffer_size}, output)); 123 return Status::OK(); 124 } 125 126 private: 127 class Iterator : public DatasetIterator<Dataset> { 128 public: 129 explicit Iterator(const Params& params) 130 : DatasetIterator<Dataset>(params) {} 131 132 Status GetNextInternal(IteratorContext* ctx, 133 std::vector<Tensor>* out_tensors, 134 bool* end_of_sequence) override { 135 mutex_lock l(mu_); 136 do { 137 // We are currently processing a file, so try to read the next line. 138 if (buffered_input_stream_) { 139 string line_contents; 140 Status s = buffered_input_stream_->ReadLine(&line_contents); 141 142 if (s.ok()) { 143 // Produce the line as output. 144 Tensor line_tensor(ctx->allocator({}), DT_STRING, {}); 145 line_tensor.scalar<string>()() = line_contents; 146 out_tensors->emplace_back(std::move(line_tensor)); 147 *end_of_sequence = false; 148 return Status::OK(); 149 } else if (!errors::IsOutOfRange(s)) { 150 // Report non-EOF errors to the caller. 151 return s; 152 } 153 // We have reached the end of the current file, so maybe 154 // move on to next file. 155 ResetStreamsLocked(); 156 ++current_file_index_; 157 } 158 159 // Iteration ends when there are no more files to process. 160 if (current_file_index_ == dataset()->filenames_.size()) { 161 *end_of_sequence = true; 162 return Status::OK(); 163 } 164 165 TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); 166 } while (true); 167 } 168 169 protected: 170 Status SaveInternal(IteratorStateWriter* writer) override { 171 mutex_lock l(mu_); 172 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"), 173 current_file_index_)); 174 175 // `buffered_input_stream_` is empty if 176 // 1. GetNext has not been called even once. 177 // 2. All files have been read and iterator has been exhausted. 178 if (buffered_input_stream_) { 179 TF_RETURN_IF_ERROR(writer->WriteScalar( 180 full_name("current_pos"), buffered_input_stream_->Tell())); 181 } 182 return Status::OK(); 183 } 184 185 Status RestoreInternal(IteratorContext* ctx, 186 IteratorStateReader* reader) override { 187 mutex_lock l(mu_); 188 ResetStreamsLocked(); 189 int64 current_file_index; 190 TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"), 191 ¤t_file_index)); 192 current_file_index_ = size_t(current_file_index); 193 // The key "current_pos" is written only if the iterator was saved 194 // with an open file. 195 if (reader->Contains(full_name("current_pos"))) { 196 int64 current_pos; 197 TF_RETURN_IF_ERROR( 198 reader->ReadScalar(full_name("current_pos"), ¤t_pos)); 199 200 TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); 201 TF_RETURN_IF_ERROR(buffered_input_stream_->Seek(current_pos)); 202 } 203 return Status::OK(); 204 } 205 206 private: 207 // Sets up reader streams to read from the file at `current_file_index_`. 208 Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { 209 if (current_file_index_ >= dataset()->filenames_.size()) { 210 return errors::InvalidArgument( 211 "current_file_index_:", current_file_index_, 212 " >= filenames_.size():", dataset()->filenames_.size()); 213 } 214 215 // Actually move on to next file. 216 TF_RETURN_IF_ERROR(env->NewRandomAccessFile( 217 dataset()->filenames_[current_file_index_], &file_)); 218 input_stream_.reset( 219 new io::RandomAccessInputStream(file_.get(), false)); 220 221 if (dataset()->use_compression_) { 222 zlib_input_stream_.reset(new io::ZlibInputStream( 223 input_stream_.get(), dataset()->options_.input_buffer_size, 224 dataset()->options_.input_buffer_size, dataset()->options_)); 225 buffered_input_stream_.reset(new io::BufferedInputStream( 226 zlib_input_stream_.get(), dataset()->options_.input_buffer_size, 227 false)); 228 } else { 229 buffered_input_stream_.reset(new io::BufferedInputStream( 230 input_stream_.get(), dataset()->options_.input_buffer_size, 231 false)); 232 } 233 return Status::OK(); 234 } 235 236 // Resets all reader streams. 237 void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { 238 input_stream_.reset(); 239 zlib_input_stream_.reset(); 240 buffered_input_stream_.reset(); 241 file_.reset(); 242 } 243 244 mutex mu_; 245 std::unique_ptr<io::RandomAccessInputStream> input_stream_ 246 GUARDED_BY(mu_); 247 std::unique_ptr<io::ZlibInputStream> zlib_input_stream_ GUARDED_BY(mu_); 248 std::unique_ptr<io::BufferedInputStream> buffered_input_stream_ 249 GUARDED_BY(mu_); 250 size_t current_file_index_ GUARDED_BY(mu_) = 0; 251 std::unique_ptr<RandomAccessFile> file_ 252 GUARDED_BY(mu_); // must outlive input_stream_ 253 }; 254 255 const std::vector<string> filenames_; 256 const string compression_type_; 257 const bool use_compression_; 258 const io::ZlibCompressionOptions options_; 259 }; 260 }; 261 262 REGISTER_KERNEL_BUILDER(Name("TextLineDataset").Device(DEVICE_CPU), 263 TextLineDatasetOp); 264 265 class FixedLengthRecordDatasetOp : public DatasetOpKernel { 266 public: 267 using DatasetOpKernel::DatasetOpKernel; 268 269 void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { 270 const Tensor* filenames_tensor; 271 OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); 272 OP_REQUIRES( 273 ctx, filenames_tensor->dims() <= 1, 274 errors::InvalidArgument("`filenames` must be a scalar or a vector.")); 275 276 std::vector<string> filenames; 277 filenames.reserve(filenames_tensor->NumElements()); 278 for (int i = 0; i < filenames_tensor->NumElements(); ++i) { 279 filenames.push_back(filenames_tensor->flat<string>()(i)); 280 } 281 282 int64 header_bytes = -1; 283 OP_REQUIRES_OK( 284 ctx, ParseScalarArgument<int64>(ctx, "header_bytes", &header_bytes)); 285 OP_REQUIRES(ctx, header_bytes >= 0, 286 errors::InvalidArgument("`header_bytes` must be >= 0")); 287 288 int64 record_bytes = -1; 289 OP_REQUIRES_OK( 290 ctx, ParseScalarArgument<int64>(ctx, "record_bytes", &record_bytes)); 291 OP_REQUIRES(ctx, record_bytes > 0, 292 errors::InvalidArgument("`record_bytes` must be > 0")); 293 294 int64 footer_bytes = -1; 295 OP_REQUIRES_OK( 296 ctx, ParseScalarArgument<int64>(ctx, "footer_bytes", &footer_bytes)); 297 OP_REQUIRES(ctx, footer_bytes >= 0, 298 errors::InvalidArgument("`footer_bytes` must be >= 0")); 299 300 int64 buffer_size = -1; 301 OP_REQUIRES_OK( 302 ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size)); 303 OP_REQUIRES(ctx, buffer_size >= 0, 304 errors::InvalidArgument("`buffer_size` must be >= 0")); 305 if (buffer_size == 0) { 306 buffer_size = 256 << 10; // 256 kB as default. 307 } 308 309 *output = new Dataset(ctx, std::move(filenames), header_bytes, record_bytes, 310 footer_bytes, buffer_size); 311 } 312 313 private: 314 class Dataset : public GraphDatasetBase { 315 public: 316 explicit Dataset(OpKernelContext* ctx, std::vector<string> filenames, 317 int64 header_bytes, int64 record_bytes, int64 footer_bytes, 318 int64 buffer_size) 319 : GraphDatasetBase(ctx), 320 filenames_(std::move(filenames)), 321 header_bytes_(header_bytes), 322 record_bytes_(record_bytes), 323 footer_bytes_(footer_bytes), 324 buffer_size_(buffer_size) {} 325 326 std::unique_ptr<IteratorBase> MakeIterator( 327 const string& prefix) const override { 328 return std::unique_ptr<IteratorBase>( 329 new Iterator({this, strings::StrCat(prefix, "::FixedLengthRecord")})); 330 } 331 332 const DataTypeVector& output_dtypes() const override { 333 static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); 334 return *dtypes; 335 } 336 337 const std::vector<PartialTensorShape>& output_shapes() const override { 338 static std::vector<PartialTensorShape>* shapes = 339 new std::vector<PartialTensorShape>({{}}); 340 return *shapes; 341 } 342 343 string DebugString() override { 344 return "FixedLengthRecordDatasetOp::Dataset"; 345 } 346 347 protected: 348 Status AsGraphDefInternal(DatasetGraphDefBuilder* b, 349 Node** output) const override { 350 Node* filenames = nullptr; 351 Node* header_bytes = nullptr; 352 Node* record_bytes = nullptr; 353 Node* footer_bytes = nullptr; 354 Node* buffer_size = nullptr; 355 TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); 356 TF_RETURN_IF_ERROR(b->AddScalar(header_bytes_, &header_bytes)); 357 TF_RETURN_IF_ERROR(b->AddScalar(record_bytes_, &record_bytes)); 358 TF_RETURN_IF_ERROR(b->AddScalar(footer_bytes_, &footer_bytes)); 359 TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size)); 360 TF_RETURN_IF_ERROR(b->AddDataset( 361 this, 362 {filenames, header_bytes, record_bytes, footer_bytes, buffer_size}, 363 output)); 364 return Status::OK(); 365 } 366 367 private: 368 class Iterator : public DatasetIterator<Dataset> { 369 public: 370 explicit Iterator(const Params& params) 371 : DatasetIterator<Dataset>(params) {} 372 373 Status GetNextInternal(IteratorContext* ctx, 374 std::vector<Tensor>* out_tensors, 375 bool* end_of_sequence) override { 376 mutex_lock l(mu_); 377 do { 378 // We are currently processing a file, so try to read the next record. 379 if (input_buffer_) { 380 const int64 current_pos = input_buffer_->Tell(); 381 DCHECK_GE(file_pos_limit_, 0); 382 if (current_pos < file_pos_limit_) { 383 string record; 384 TF_RETURN_IF_ERROR( 385 input_buffer_->ReadNBytes(dataset()->record_bytes_, &record)); 386 // Produce the record as output. 387 Tensor record_tensor(ctx->allocator({}), DT_STRING, {}); 388 record_tensor.scalar<string>()() = record; 389 out_tensors->emplace_back(std::move(record_tensor)); 390 *end_of_sequence = false; 391 return Status::OK(); 392 } 393 394 // We have reached the end of the current file, so maybe 395 // move on to next file. 396 input_buffer_.reset(); 397 file_.reset(); 398 ++current_file_index_; 399 } 400 401 // Iteration ends when there are no more files to process. 402 if (current_file_index_ == dataset()->filenames_.size()) { 403 *end_of_sequence = true; 404 return Status::OK(); 405 } 406 407 // Actually move on to next file. 408 uint64 file_size; 409 TF_RETURN_IF_ERROR(ctx->env()->GetFileSize( 410 dataset()->filenames_[current_file_index_], &file_size)); 411 file_pos_limit_ = file_size - dataset()->footer_bytes_; 412 413 uint64 body_size = 414 file_size - (dataset()->header_bytes_ + dataset()->footer_bytes_); 415 416 if (body_size % dataset()->record_bytes_ != 0) { 417 return errors::InvalidArgument( 418 "Excluding the header (", dataset()->header_bytes_, 419 " bytes) and footer (", dataset()->footer_bytes_, 420 " bytes), input file \"", 421 dataset()->filenames_[current_file_index_], 422 "\" has body length ", body_size, 423 " bytes, which is not an exact multiple of the record length (", 424 dataset()->record_bytes_, " bytes)."); 425 } 426 TF_RETURN_IF_ERROR(ctx->env()->NewRandomAccessFile( 427 dataset()->filenames_[current_file_index_], &file_)); 428 input_buffer_.reset( 429 new io::InputBuffer(file_.get(), dataset()->buffer_size_)); 430 TF_RETURN_IF_ERROR( 431 input_buffer_->SkipNBytes(dataset()->header_bytes_)); 432 } while (true); 433 } 434 435 protected: 436 Status SaveInternal(IteratorStateWriter* writer) override { 437 mutex_lock l(mu_); 438 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"), 439 current_file_index_)); 440 441 // `input_buffer_` is empty if 442 // 1. GetNext has not been called even once. 443 // 2. All files have been read and iterator has been exhausted. 444 int64 current_pos = input_buffer_ ? input_buffer_->Tell() : -1; 445 TF_RETURN_IF_ERROR( 446 writer->WriteScalar(full_name("current_pos"), current_pos)); 447 return Status::OK(); 448 } 449 450 Status RestoreInternal(IteratorContext* ctx, 451 IteratorStateReader* reader) override { 452 mutex_lock l(mu_); 453 int64 current_file_index; 454 TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"), 455 ¤t_file_index)); 456 current_file_index_ = size_t(current_file_index); 457 int64 current_pos; 458 TF_RETURN_IF_ERROR( 459 reader->ReadScalar(full_name("current_pos"), ¤t_pos)); 460 461 // Seek to current_pos. 462 input_buffer_.reset(); 463 file_.reset(); 464 if (current_pos >= 0) { // There was an active input_buffer_. 465 uint64 file_size; 466 TF_RETURN_IF_ERROR(ctx->env()->GetFileSize( 467 dataset()->filenames_[current_file_index_], &file_size)); 468 file_pos_limit_ = file_size - dataset()->footer_bytes_; 469 TF_RETURN_IF_ERROR(ctx->env()->NewRandomAccessFile( 470 dataset()->filenames_[current_file_index_], &file_)); 471 input_buffer_.reset( 472 new io::InputBuffer(file_.get(), dataset()->buffer_size_)); 473 TF_RETURN_IF_ERROR(input_buffer_->Seek(current_pos)); 474 } 475 476 return Status::OK(); 477 } 478 479 private: 480 mutex mu_; 481 size_t current_file_index_ GUARDED_BY(mu_) = 0; 482 std::unique_ptr<RandomAccessFile> file_ 483 GUARDED_BY(mu_); // must outlive input_buffer_ 484 std::unique_ptr<io::InputBuffer> input_buffer_ GUARDED_BY(mu_); 485 int64 file_pos_limit_ GUARDED_BY(mu_) = -1; 486 }; 487 488 const std::vector<string> filenames_; 489 const int64 header_bytes_; 490 const int64 record_bytes_; 491 const int64 footer_bytes_; 492 const int64 buffer_size_; 493 }; 494 }; 495 496 REGISTER_KERNEL_BUILDER(Name("FixedLengthRecordDataset").Device(DEVICE_CPU), 497 FixedLengthRecordDatasetOp); 498 499 class TFRecordDatasetOp : public DatasetOpKernel { 500 public: 501 using DatasetOpKernel::DatasetOpKernel; 502 503 void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { 504 const Tensor* filenames_tensor; 505 OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); 506 OP_REQUIRES( 507 ctx, filenames_tensor->dims() <= 1, 508 errors::InvalidArgument("`filenames` must be a scalar or a vector.")); 509 510 std::vector<string> filenames; 511 filenames.reserve(filenames_tensor->NumElements()); 512 for (int i = 0; i < filenames_tensor->NumElements(); ++i) { 513 filenames.push_back(filenames_tensor->flat<string>()(i)); 514 } 515 516 string compression_type; 517 OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "compression_type", 518 &compression_type)); 519 520 int64 buffer_size = -1; 521 OP_REQUIRES_OK( 522 ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size)); 523 OP_REQUIRES(ctx, buffer_size >= 0, 524 errors::InvalidArgument( 525 "`buffer_size` must be >= 0 (0 == no buffering)")); 526 527 *output = 528 new Dataset(ctx, std::move(filenames), compression_type, buffer_size); 529 } 530 531 private: 532 class Dataset : public GraphDatasetBase { 533 public: 534 explicit Dataset(OpKernelContext* ctx, std::vector<string> filenames, 535 const string& compression_type, int64 buffer_size) 536 : GraphDatasetBase(ctx), 537 filenames_(std::move(filenames)), 538 compression_type_(compression_type), 539 options_(io::RecordReaderOptions::CreateRecordReaderOptions( 540 compression_type)) { 541 if (buffer_size > 0) { 542 options_.buffer_size = buffer_size; 543 } 544 } 545 546 std::unique_ptr<IteratorBase> MakeIterator( 547 const string& prefix) const override { 548 return std::unique_ptr<IteratorBase>( 549 new Iterator({this, strings::StrCat(prefix, "::TFRecord")})); 550 } 551 552 const DataTypeVector& output_dtypes() const override { 553 static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); 554 return *dtypes; 555 } 556 557 const std::vector<PartialTensorShape>& output_shapes() const override { 558 static std::vector<PartialTensorShape>* shapes = 559 new std::vector<PartialTensorShape>({{}}); 560 return *shapes; 561 } 562 563 string DebugString() override { return "TFRecordDatasetOp::Dataset"; } 564 565 protected: 566 Status AsGraphDefInternal(DatasetGraphDefBuilder* b, 567 Node** output) const override { 568 Node* filenames = nullptr; 569 TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); 570 Node* compression_type = nullptr; 571 TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type)); 572 Node* buffer_size = nullptr; 573 TF_RETURN_IF_ERROR(b->AddScalar(options_.buffer_size, &buffer_size)); 574 TF_RETURN_IF_ERROR(b->AddDataset( 575 this, {filenames, compression_type, buffer_size}, output)); 576 return Status::OK(); 577 } 578 579 private: 580 class Iterator : public DatasetIterator<Dataset> { 581 public: 582 explicit Iterator(const Params& params) 583 : DatasetIterator<Dataset>(params) {} 584 585 Status GetNextInternal(IteratorContext* ctx, 586 std::vector<Tensor>* out_tensors, 587 bool* end_of_sequence) override { 588 mutex_lock l(mu_); 589 do { 590 // We are currently processing a file, so try to read the next record. 591 if (reader_) { 592 Tensor result_tensor(ctx->allocator({}), DT_STRING, {}); 593 Status s = reader_->ReadRecord(&result_tensor.scalar<string>()()); 594 if (s.ok()) { 595 out_tensors->emplace_back(std::move(result_tensor)); 596 *end_of_sequence = false; 597 return Status::OK(); 598 } else if (!errors::IsOutOfRange(s)) { 599 return s; 600 } 601 602 // We have reached the end of the current file, so maybe 603 // move on to next file. 604 ResetStreamsLocked(); 605 ++current_file_index_; 606 } 607 608 // Iteration ends when there are no more files to process. 609 if (current_file_index_ == dataset()->filenames_.size()) { 610 *end_of_sequence = true; 611 return Status::OK(); 612 } 613 614 TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); 615 } while (true); 616 } 617 618 protected: 619 Status SaveInternal(IteratorStateWriter* writer) override { 620 mutex_lock l(mu_); 621 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"), 622 current_file_index_)); 623 624 if (reader_) { 625 TF_RETURN_IF_ERROR( 626 writer->WriteScalar(full_name("offset"), reader_->TellOffset())); 627 } 628 return Status::OK(); 629 } 630 631 Status RestoreInternal(IteratorContext* ctx, 632 IteratorStateReader* reader) override { 633 mutex_lock l(mu_); 634 ResetStreamsLocked(); 635 int64 current_file_index; 636 TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"), 637 ¤t_file_index)); 638 current_file_index_ = size_t(current_file_index); 639 if (reader->Contains(full_name("offset"))) { 640 int64 offset; 641 TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("offset"), &offset)); 642 TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); 643 TF_RETURN_IF_ERROR(reader_->SeekOffset(offset)); 644 } 645 return Status::OK(); 646 } 647 648 private: 649 // Sets up reader streams to read from the file at `current_file_index_`. 650 Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { 651 if (current_file_index_ >= dataset()->filenames_.size()) { 652 return errors::InvalidArgument( 653 "current_file_index_:", current_file_index_, 654 " >= filenames_.size():", dataset()->filenames_.size()); 655 } 656 657 // Actually move on to next file. 658 const string& next_filename = 659 dataset()->filenames_[current_file_index_]; 660 TF_RETURN_IF_ERROR(env->NewRandomAccessFile(next_filename, &file_)); 661 reader_.reset( 662 new io::SequentialRecordReader(file_.get(), dataset()->options_)); 663 return Status::OK(); 664 } 665 666 // Resets all reader streams. 667 void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { 668 reader_.reset(); 669 file_.reset(); 670 } 671 672 mutex mu_; 673 size_t current_file_index_ GUARDED_BY(mu_) = 0; 674 675 // `reader_` will borrow the object that `file_` points to, so 676 // we must destroy `reader_` before `file_`. 677 std::unique_ptr<RandomAccessFile> file_ GUARDED_BY(mu_); 678 std::unique_ptr<io::SequentialRecordReader> reader_ GUARDED_BY(mu_); 679 }; 680 681 const std::vector<string> filenames_; 682 const string compression_type_; 683 io::RecordReaderOptions options_; 684 }; 685 }; 686 687 REGISTER_KERNEL_BUILDER(Name("TFRecordDataset").Device(DEVICE_CPU), 688 TFRecordDatasetOp); 689 690 } // namespace 691 692 } // namespace tensorflow 693