Home | History | Annotate | Download | only in data
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/core/framework/partial_tensor_shape.h"
     16 #include "tensorflow/core/framework/tensor.h"
     17 #include "tensorflow/core/kernels/data/dataset.h"
     18 #include "tensorflow/core/lib/io/buffered_inputstream.h"
     19 #include "tensorflow/core/lib/io/inputbuffer.h"
     20 #include "tensorflow/core/lib/io/random_inputstream.h"
     21 #include "tensorflow/core/lib/io/record_reader.h"
     22 #include "tensorflow/core/lib/io/zlib_compression_options.h"
     23 #include "tensorflow/core/lib/io/zlib_inputstream.h"
     24 
     25 namespace tensorflow {
     26 
     27 namespace {
     28 
     29 // See documentation in ../ops/dataset_ops.cc for a high-level
     30 // description of the following ops.
     31 
     32 class TextLineDatasetOp : public DatasetOpKernel {
     33  public:
     34   using DatasetOpKernel::DatasetOpKernel;
     35 
     36   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
     37     const Tensor* filenames_tensor;
     38     OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
     39     OP_REQUIRES(
     40         ctx, filenames_tensor->dims() <= 1,
     41         errors::InvalidArgument("`filenames` must be a scalar or a vector."));
     42 
     43     string compression_type;
     44     OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "compression_type",
     45                                                     &compression_type));
     46 
     47     int64 buffer_size = -1;
     48     OP_REQUIRES_OK(
     49         ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
     50     OP_REQUIRES(
     51         ctx, buffer_size >= 0,
     52         errors::InvalidArgument("`buffer_size` must be >= 0 (0 == default)"));
     53 
     54     io::ZlibCompressionOptions zlib_compression_options =
     55         io::ZlibCompressionOptions::DEFAULT();
     56     if (compression_type == "ZLIB") {
     57       zlib_compression_options = io::ZlibCompressionOptions::DEFAULT();
     58     } else if (compression_type == "GZIP") {
     59       zlib_compression_options = io::ZlibCompressionOptions::GZIP();
     60     } else {
     61       OP_REQUIRES(ctx, compression_type.empty(),
     62                   errors::InvalidArgument("Unsupported compression_type."));
     63     }
     64 
     65     if (buffer_size != 0) {
     66       // Set the override size.
     67       zlib_compression_options.input_buffer_size = buffer_size;
     68     }
     69 
     70     std::vector<string> filenames;
     71     filenames.reserve(filenames_tensor->NumElements());
     72     for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
     73       filenames.push_back(filenames_tensor->flat<string>()(i));
     74     }
     75 
     76     *output = new Dataset(ctx, std::move(filenames), compression_type,
     77                           zlib_compression_options);
     78   }
     79 
     80  private:
     81   class Dataset : public GraphDatasetBase {
     82    public:
     83     Dataset(OpKernelContext* ctx, std::vector<string> filenames,
     84             const string& compression_type,
     85             const io::ZlibCompressionOptions& options)
     86         : GraphDatasetBase(ctx),
     87           filenames_(std::move(filenames)),
     88           compression_type_(compression_type),
     89           use_compression_(!compression_type.empty()),
     90           options_(options) {}
     91 
     92     std::unique_ptr<IteratorBase> MakeIterator(
     93         const string& prefix) const override {
     94       return std::unique_ptr<IteratorBase>(
     95           new Iterator({this, strings::StrCat(prefix, "::TextLine")}));
     96     }
     97 
     98     const DataTypeVector& output_dtypes() const override {
     99       static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
    100       return *dtypes;
    101     }
    102 
    103     const std::vector<PartialTensorShape>& output_shapes() const override {
    104       static std::vector<PartialTensorShape>* shapes =
    105           new std::vector<PartialTensorShape>({{}});
    106       return *shapes;
    107     }
    108 
    109     string DebugString() override { return "TextLineDatasetOp::Dataset"; }
    110 
    111    protected:
    112     Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
    113                               Node** output) const override {
    114       Node* filenames = nullptr;
    115       Node* compression_type = nullptr;
    116       Node* buffer_size = nullptr;
    117       TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
    118       TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type));
    119       TF_RETURN_IF_ERROR(
    120           b->AddScalar(options_.input_buffer_size, &buffer_size));
    121       TF_RETURN_IF_ERROR(b->AddDataset(
    122           this, {filenames, compression_type, buffer_size}, output));
    123       return Status::OK();
    124     }
    125 
    126    private:
    127     class Iterator : public DatasetIterator<Dataset> {
    128      public:
    129       explicit Iterator(const Params& params)
    130           : DatasetIterator<Dataset>(params) {}
    131 
    132       Status GetNextInternal(IteratorContext* ctx,
    133                              std::vector<Tensor>* out_tensors,
    134                              bool* end_of_sequence) override {
    135         mutex_lock l(mu_);
    136         do {
    137           // We are currently processing a file, so try to read the next line.
    138           if (buffered_input_stream_) {
    139             string line_contents;
    140             Status s = buffered_input_stream_->ReadLine(&line_contents);
    141 
    142             if (s.ok()) {
    143               // Produce the line as output.
    144               Tensor line_tensor(ctx->allocator({}), DT_STRING, {});
    145               line_tensor.scalar<string>()() = line_contents;
    146               out_tensors->emplace_back(std::move(line_tensor));
    147               *end_of_sequence = false;
    148               return Status::OK();
    149             } else if (!errors::IsOutOfRange(s)) {
    150               // Report non-EOF errors to the caller.
    151               return s;
    152             }
    153             // We have reached the end of the current file, so maybe
    154             // move on to next file.
    155             ResetStreamsLocked();
    156             ++current_file_index_;
    157           }
    158 
    159           // Iteration ends when there are no more files to process.
    160           if (current_file_index_ == dataset()->filenames_.size()) {
    161             *end_of_sequence = true;
    162             return Status::OK();
    163           }
    164 
    165           TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
    166         } while (true);
    167       }
    168 
    169      protected:
    170       Status SaveInternal(IteratorStateWriter* writer) override {
    171         mutex_lock l(mu_);
    172         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
    173                                                current_file_index_));
    174 
    175         // `buffered_input_stream_` is empty if
    176         // 1. GetNext has not been called even once.
    177         // 2. All files have been read and iterator has been exhausted.
    178         if (buffered_input_stream_) {
    179           TF_RETURN_IF_ERROR(writer->WriteScalar(
    180               full_name("current_pos"), buffered_input_stream_->Tell()));
    181         }
    182         return Status::OK();
    183       }
    184 
    185       Status RestoreInternal(IteratorContext* ctx,
    186                              IteratorStateReader* reader) override {
    187         mutex_lock l(mu_);
    188         ResetStreamsLocked();
    189         int64 current_file_index;
    190         TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
    191                                               &current_file_index));
    192         current_file_index_ = size_t(current_file_index);
    193         // The key "current_pos" is written only if the iterator was saved
    194         // with an open file.
    195         if (reader->Contains(full_name("current_pos"))) {
    196           int64 current_pos;
    197           TF_RETURN_IF_ERROR(
    198               reader->ReadScalar(full_name("current_pos"), &current_pos));
    199 
    200           TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
    201           TF_RETURN_IF_ERROR(buffered_input_stream_->Seek(current_pos));
    202         }
    203         return Status::OK();
    204       }
    205 
    206      private:
    207       // Sets up reader streams to read from the file at `current_file_index_`.
    208       Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    209         if (current_file_index_ >= dataset()->filenames_.size()) {
    210           return errors::InvalidArgument(
    211               "current_file_index_:", current_file_index_,
    212               " >= filenames_.size():", dataset()->filenames_.size());
    213         }
    214 
    215         // Actually move on to next file.
    216         TF_RETURN_IF_ERROR(env->NewRandomAccessFile(
    217             dataset()->filenames_[current_file_index_], &file_));
    218         input_stream_.reset(
    219             new io::RandomAccessInputStream(file_.get(), false));
    220 
    221         if (dataset()->use_compression_) {
    222           zlib_input_stream_.reset(new io::ZlibInputStream(
    223               input_stream_.get(), dataset()->options_.input_buffer_size,
    224               dataset()->options_.input_buffer_size, dataset()->options_));
    225           buffered_input_stream_.reset(new io::BufferedInputStream(
    226               zlib_input_stream_.get(), dataset()->options_.input_buffer_size,
    227               false));
    228         } else {
    229           buffered_input_stream_.reset(new io::BufferedInputStream(
    230               input_stream_.get(), dataset()->options_.input_buffer_size,
    231               false));
    232         }
    233         return Status::OK();
    234       }
    235 
    236       // Resets all reader streams.
    237       void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    238         input_stream_.reset();
    239         zlib_input_stream_.reset();
    240         buffered_input_stream_.reset();
    241         file_.reset();
    242       }
    243 
    244       mutex mu_;
    245       std::unique_ptr<io::RandomAccessInputStream> input_stream_
    246           GUARDED_BY(mu_);
    247       std::unique_ptr<io::ZlibInputStream> zlib_input_stream_ GUARDED_BY(mu_);
    248       std::unique_ptr<io::BufferedInputStream> buffered_input_stream_
    249           GUARDED_BY(mu_);
    250       size_t current_file_index_ GUARDED_BY(mu_) = 0;
    251       std::unique_ptr<RandomAccessFile> file_
    252           GUARDED_BY(mu_);  // must outlive input_stream_
    253     };
    254 
    255     const std::vector<string> filenames_;
    256     const string compression_type_;
    257     const bool use_compression_;
    258     const io::ZlibCompressionOptions options_;
    259   };
    260 };
    261 
    262 REGISTER_KERNEL_BUILDER(Name("TextLineDataset").Device(DEVICE_CPU),
    263                         TextLineDatasetOp);
    264 
    265 class FixedLengthRecordDatasetOp : public DatasetOpKernel {
    266  public:
    267   using DatasetOpKernel::DatasetOpKernel;
    268 
    269   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
    270     const Tensor* filenames_tensor;
    271     OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
    272     OP_REQUIRES(
    273         ctx, filenames_tensor->dims() <= 1,
    274         errors::InvalidArgument("`filenames` must be a scalar or a vector."));
    275 
    276     std::vector<string> filenames;
    277     filenames.reserve(filenames_tensor->NumElements());
    278     for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
    279       filenames.push_back(filenames_tensor->flat<string>()(i));
    280     }
    281 
    282     int64 header_bytes = -1;
    283     OP_REQUIRES_OK(
    284         ctx, ParseScalarArgument<int64>(ctx, "header_bytes", &header_bytes));
    285     OP_REQUIRES(ctx, header_bytes >= 0,
    286                 errors::InvalidArgument("`header_bytes` must be >= 0"));
    287 
    288     int64 record_bytes = -1;
    289     OP_REQUIRES_OK(
    290         ctx, ParseScalarArgument<int64>(ctx, "record_bytes", &record_bytes));
    291     OP_REQUIRES(ctx, record_bytes > 0,
    292                 errors::InvalidArgument("`record_bytes` must be > 0"));
    293 
    294     int64 footer_bytes = -1;
    295     OP_REQUIRES_OK(
    296         ctx, ParseScalarArgument<int64>(ctx, "footer_bytes", &footer_bytes));
    297     OP_REQUIRES(ctx, footer_bytes >= 0,
    298                 errors::InvalidArgument("`footer_bytes` must be >= 0"));
    299 
    300     int64 buffer_size = -1;
    301     OP_REQUIRES_OK(
    302         ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
    303     OP_REQUIRES(ctx, buffer_size >= 0,
    304                 errors::InvalidArgument("`buffer_size` must be >= 0"));
    305     if (buffer_size == 0) {
    306       buffer_size = 256 << 10;  // 256 kB as default.
    307     }
    308 
    309     *output = new Dataset(ctx, std::move(filenames), header_bytes, record_bytes,
    310                           footer_bytes, buffer_size);
    311   }
    312 
    313  private:
    314   class Dataset : public GraphDatasetBase {
    315    public:
    316     explicit Dataset(OpKernelContext* ctx, std::vector<string> filenames,
    317                      int64 header_bytes, int64 record_bytes, int64 footer_bytes,
    318                      int64 buffer_size)
    319         : GraphDatasetBase(ctx),
    320           filenames_(std::move(filenames)),
    321           header_bytes_(header_bytes),
    322           record_bytes_(record_bytes),
    323           footer_bytes_(footer_bytes),
    324           buffer_size_(buffer_size) {}
    325 
    326     std::unique_ptr<IteratorBase> MakeIterator(
    327         const string& prefix) const override {
    328       return std::unique_ptr<IteratorBase>(
    329           new Iterator({this, strings::StrCat(prefix, "::FixedLengthRecord")}));
    330     }
    331 
    332     const DataTypeVector& output_dtypes() const override {
    333       static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
    334       return *dtypes;
    335     }
    336 
    337     const std::vector<PartialTensorShape>& output_shapes() const override {
    338       static std::vector<PartialTensorShape>* shapes =
    339           new std::vector<PartialTensorShape>({{}});
    340       return *shapes;
    341     }
    342 
    343     string DebugString() override {
    344       return "FixedLengthRecordDatasetOp::Dataset";
    345     }
    346 
    347    protected:
    348     Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
    349                               Node** output) const override {
    350       Node* filenames = nullptr;
    351       Node* header_bytes = nullptr;
    352       Node* record_bytes = nullptr;
    353       Node* footer_bytes = nullptr;
    354       Node* buffer_size = nullptr;
    355       TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
    356       TF_RETURN_IF_ERROR(b->AddScalar(header_bytes_, &header_bytes));
    357       TF_RETURN_IF_ERROR(b->AddScalar(record_bytes_, &record_bytes));
    358       TF_RETURN_IF_ERROR(b->AddScalar(footer_bytes_, &footer_bytes));
    359       TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
    360       TF_RETURN_IF_ERROR(b->AddDataset(
    361           this,
    362           {filenames, header_bytes, record_bytes, footer_bytes, buffer_size},
    363           output));
    364       return Status::OK();
    365     }
    366 
    367    private:
    368     class Iterator : public DatasetIterator<Dataset> {
    369      public:
    370       explicit Iterator(const Params& params)
    371           : DatasetIterator<Dataset>(params) {}
    372 
    373       Status GetNextInternal(IteratorContext* ctx,
    374                              std::vector<Tensor>* out_tensors,
    375                              bool* end_of_sequence) override {
    376         mutex_lock l(mu_);
    377         do {
    378           // We are currently processing a file, so try to read the next record.
    379           if (input_buffer_) {
    380             const int64 current_pos = input_buffer_->Tell();
    381             DCHECK_GE(file_pos_limit_, 0);
    382             if (current_pos < file_pos_limit_) {
    383               string record;
    384               TF_RETURN_IF_ERROR(
    385                   input_buffer_->ReadNBytes(dataset()->record_bytes_, &record));
    386               // Produce the record as output.
    387               Tensor record_tensor(ctx->allocator({}), DT_STRING, {});
    388               record_tensor.scalar<string>()() = record;
    389               out_tensors->emplace_back(std::move(record_tensor));
    390               *end_of_sequence = false;
    391               return Status::OK();
    392             }
    393 
    394             // We have reached the end of the current file, so maybe
    395             // move on to next file.
    396             input_buffer_.reset();
    397             file_.reset();
    398             ++current_file_index_;
    399           }
    400 
    401           // Iteration ends when there are no more files to process.
    402           if (current_file_index_ == dataset()->filenames_.size()) {
    403             *end_of_sequence = true;
    404             return Status::OK();
    405           }
    406 
    407           // Actually move on to next file.
    408           uint64 file_size;
    409           TF_RETURN_IF_ERROR(ctx->env()->GetFileSize(
    410               dataset()->filenames_[current_file_index_], &file_size));
    411           file_pos_limit_ = file_size - dataset()->footer_bytes_;
    412 
    413           uint64 body_size =
    414               file_size - (dataset()->header_bytes_ + dataset()->footer_bytes_);
    415 
    416           if (body_size % dataset()->record_bytes_ != 0) {
    417             return errors::InvalidArgument(
    418                 "Excluding the header (", dataset()->header_bytes_,
    419                 " bytes) and footer (", dataset()->footer_bytes_,
    420                 " bytes), input file \"",
    421                 dataset()->filenames_[current_file_index_],
    422                 "\" has body length ", body_size,
    423                 " bytes, which is not an exact multiple of the record length (",
    424                 dataset()->record_bytes_, " bytes).");
    425           }
    426           TF_RETURN_IF_ERROR(ctx->env()->NewRandomAccessFile(
    427               dataset()->filenames_[current_file_index_], &file_));
    428           input_buffer_.reset(
    429               new io::InputBuffer(file_.get(), dataset()->buffer_size_));
    430           TF_RETURN_IF_ERROR(
    431               input_buffer_->SkipNBytes(dataset()->header_bytes_));
    432         } while (true);
    433       }
    434 
    435      protected:
    436       Status SaveInternal(IteratorStateWriter* writer) override {
    437         mutex_lock l(mu_);
    438         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
    439                                                current_file_index_));
    440 
    441         // `input_buffer_` is empty if
    442         // 1. GetNext has not been called even once.
    443         // 2. All files have been read and iterator has been exhausted.
    444         int64 current_pos = input_buffer_ ? input_buffer_->Tell() : -1;
    445         TF_RETURN_IF_ERROR(
    446             writer->WriteScalar(full_name("current_pos"), current_pos));
    447         return Status::OK();
    448       }
    449 
    450       Status RestoreInternal(IteratorContext* ctx,
    451                              IteratorStateReader* reader) override {
    452         mutex_lock l(mu_);
    453         int64 current_file_index;
    454         TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
    455                                               &current_file_index));
    456         current_file_index_ = size_t(current_file_index);
    457         int64 current_pos;
    458         TF_RETURN_IF_ERROR(
    459             reader->ReadScalar(full_name("current_pos"), &current_pos));
    460 
    461         // Seek to current_pos.
    462         input_buffer_.reset();
    463         file_.reset();
    464         if (current_pos >= 0) {  // There was an active input_buffer_.
    465           uint64 file_size;
    466           TF_RETURN_IF_ERROR(ctx->env()->GetFileSize(
    467               dataset()->filenames_[current_file_index_], &file_size));
    468           file_pos_limit_ = file_size - dataset()->footer_bytes_;
    469           TF_RETURN_IF_ERROR(ctx->env()->NewRandomAccessFile(
    470               dataset()->filenames_[current_file_index_], &file_));
    471           input_buffer_.reset(
    472               new io::InputBuffer(file_.get(), dataset()->buffer_size_));
    473           TF_RETURN_IF_ERROR(input_buffer_->Seek(current_pos));
    474         }
    475 
    476         return Status::OK();
    477       }
    478 
    479      private:
    480       mutex mu_;
    481       size_t current_file_index_ GUARDED_BY(mu_) = 0;
    482       std::unique_ptr<RandomAccessFile> file_
    483           GUARDED_BY(mu_);  // must outlive input_buffer_
    484       std::unique_ptr<io::InputBuffer> input_buffer_ GUARDED_BY(mu_);
    485       int64 file_pos_limit_ GUARDED_BY(mu_) = -1;
    486     };
    487 
    488     const std::vector<string> filenames_;
    489     const int64 header_bytes_;
    490     const int64 record_bytes_;
    491     const int64 footer_bytes_;
    492     const int64 buffer_size_;
    493   };
    494 };
    495 
    496 REGISTER_KERNEL_BUILDER(Name("FixedLengthRecordDataset").Device(DEVICE_CPU),
    497                         FixedLengthRecordDatasetOp);
    498 
    499 class TFRecordDatasetOp : public DatasetOpKernel {
    500  public:
    501   using DatasetOpKernel::DatasetOpKernel;
    502 
    503   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
    504     const Tensor* filenames_tensor;
    505     OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
    506     OP_REQUIRES(
    507         ctx, filenames_tensor->dims() <= 1,
    508         errors::InvalidArgument("`filenames` must be a scalar or a vector."));
    509 
    510     std::vector<string> filenames;
    511     filenames.reserve(filenames_tensor->NumElements());
    512     for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
    513       filenames.push_back(filenames_tensor->flat<string>()(i));
    514     }
    515 
    516     string compression_type;
    517     OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "compression_type",
    518                                                     &compression_type));
    519 
    520     int64 buffer_size = -1;
    521     OP_REQUIRES_OK(
    522         ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
    523     OP_REQUIRES(ctx, buffer_size >= 0,
    524                 errors::InvalidArgument(
    525                     "`buffer_size` must be >= 0 (0 == no buffering)"));
    526 
    527     *output =
    528         new Dataset(ctx, std::move(filenames), compression_type, buffer_size);
    529   }
    530 
    531  private:
    532   class Dataset : public GraphDatasetBase {
    533    public:
    534     explicit Dataset(OpKernelContext* ctx, std::vector<string> filenames,
    535                      const string& compression_type, int64 buffer_size)
    536         : GraphDatasetBase(ctx),
    537           filenames_(std::move(filenames)),
    538           compression_type_(compression_type),
    539           options_(io::RecordReaderOptions::CreateRecordReaderOptions(
    540               compression_type)) {
    541       if (buffer_size > 0) {
    542         options_.buffer_size = buffer_size;
    543       }
    544     }
    545 
    546     std::unique_ptr<IteratorBase> MakeIterator(
    547         const string& prefix) const override {
    548       return std::unique_ptr<IteratorBase>(
    549           new Iterator({this, strings::StrCat(prefix, "::TFRecord")}));
    550     }
    551 
    552     const DataTypeVector& output_dtypes() const override {
    553       static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
    554       return *dtypes;
    555     }
    556 
    557     const std::vector<PartialTensorShape>& output_shapes() const override {
    558       static std::vector<PartialTensorShape>* shapes =
    559           new std::vector<PartialTensorShape>({{}});
    560       return *shapes;
    561     }
    562 
    563     string DebugString() override { return "TFRecordDatasetOp::Dataset"; }
    564 
    565    protected:
    566     Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
    567                               Node** output) const override {
    568       Node* filenames = nullptr;
    569       TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
    570       Node* compression_type = nullptr;
    571       TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type));
    572       Node* buffer_size = nullptr;
    573       TF_RETURN_IF_ERROR(b->AddScalar(options_.buffer_size, &buffer_size));
    574       TF_RETURN_IF_ERROR(b->AddDataset(
    575           this, {filenames, compression_type, buffer_size}, output));
    576       return Status::OK();
    577     }
    578 
    579    private:
    580     class Iterator : public DatasetIterator<Dataset> {
    581      public:
    582       explicit Iterator(const Params& params)
    583           : DatasetIterator<Dataset>(params) {}
    584 
    585       Status GetNextInternal(IteratorContext* ctx,
    586                              std::vector<Tensor>* out_tensors,
    587                              bool* end_of_sequence) override {
    588         mutex_lock l(mu_);
    589         do {
    590           // We are currently processing a file, so try to read the next record.
    591           if (reader_) {
    592             Tensor result_tensor(ctx->allocator({}), DT_STRING, {});
    593             Status s = reader_->ReadRecord(&result_tensor.scalar<string>()());
    594             if (s.ok()) {
    595               out_tensors->emplace_back(std::move(result_tensor));
    596               *end_of_sequence = false;
    597               return Status::OK();
    598             } else if (!errors::IsOutOfRange(s)) {
    599               return s;
    600             }
    601 
    602             // We have reached the end of the current file, so maybe
    603             // move on to next file.
    604             ResetStreamsLocked();
    605             ++current_file_index_;
    606           }
    607 
    608           // Iteration ends when there are no more files to process.
    609           if (current_file_index_ == dataset()->filenames_.size()) {
    610             *end_of_sequence = true;
    611             return Status::OK();
    612           }
    613 
    614           TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
    615         } while (true);
    616       }
    617 
    618      protected:
    619       Status SaveInternal(IteratorStateWriter* writer) override {
    620         mutex_lock l(mu_);
    621         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
    622                                                current_file_index_));
    623 
    624         if (reader_) {
    625           TF_RETURN_IF_ERROR(
    626               writer->WriteScalar(full_name("offset"), reader_->TellOffset()));
    627         }
    628         return Status::OK();
    629       }
    630 
    631       Status RestoreInternal(IteratorContext* ctx,
    632                              IteratorStateReader* reader) override {
    633         mutex_lock l(mu_);
    634         ResetStreamsLocked();
    635         int64 current_file_index;
    636         TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
    637                                               &current_file_index));
    638         current_file_index_ = size_t(current_file_index);
    639         if (reader->Contains(full_name("offset"))) {
    640           int64 offset;
    641           TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("offset"), &offset));
    642           TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
    643           TF_RETURN_IF_ERROR(reader_->SeekOffset(offset));
    644         }
    645         return Status::OK();
    646       }
    647 
    648      private:
    649       // Sets up reader streams to read from the file at `current_file_index_`.
    650       Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    651         if (current_file_index_ >= dataset()->filenames_.size()) {
    652           return errors::InvalidArgument(
    653               "current_file_index_:", current_file_index_,
    654               " >= filenames_.size():", dataset()->filenames_.size());
    655         }
    656 
    657         // Actually move on to next file.
    658         const string& next_filename =
    659             dataset()->filenames_[current_file_index_];
    660         TF_RETURN_IF_ERROR(env->NewRandomAccessFile(next_filename, &file_));
    661         reader_.reset(
    662             new io::SequentialRecordReader(file_.get(), dataset()->options_));
    663         return Status::OK();
    664       }
    665 
    666       // Resets all reader streams.
    667       void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    668         reader_.reset();
    669         file_.reset();
    670       }
    671 
    672       mutex mu_;
    673       size_t current_file_index_ GUARDED_BY(mu_) = 0;
    674 
    675       // `reader_` will borrow the object that `file_` points to, so
    676       // we must destroy `reader_` before `file_`.
    677       std::unique_ptr<RandomAccessFile> file_ GUARDED_BY(mu_);
    678       std::unique_ptr<io::SequentialRecordReader> reader_ GUARDED_BY(mu_);
    679     };
    680 
    681     const std::vector<string> filenames_;
    682     const string compression_type_;
    683     io::RecordReaderOptions options_;
    684   };
    685 };
    686 
    687 REGISTER_KERNEL_BUILDER(Name("TFRecordDataset").Device(DEVICE_CPU),
    688                         TFRecordDatasetOp);
    689 
    690 }  // namespace
    691 
    692 }  // namespace tensorflow
    693