Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/io_ops.cc.
     17 
     18 #include <memory>
     19 #include "tensorflow/core/framework/reader_base.h"
     20 #include "tensorflow/core/framework/reader_op_kernel.h"
     21 #include "tensorflow/core/lib/core/errors.h"
     22 #include "tensorflow/core/lib/io/buffered_inputstream.h"
     23 #include "tensorflow/core/lib/io/random_inputstream.h"
     24 #include "tensorflow/core/lib/io/zlib_compression_options.h"
     25 #include "tensorflow/core/lib/io/zlib_inputstream.h"
     26 #include "tensorflow/core/lib/strings/strcat.h"
     27 #include "tensorflow/core/platform/env.h"
     28 
     29 namespace tensorflow {
     30 
     31 // In the constructor hop_bytes_ is set to record_bytes_ if it was 0,
     32 // so that we will always "hop" after each read (except first).
     33 class FixedLengthRecordReader : public ReaderBase {
     34  public:
     35   FixedLengthRecordReader(const string& node_name, int64 header_bytes,
     36                           int64 record_bytes, int64 footer_bytes,
     37                           int64 hop_bytes, const string& encoding, Env* env)
     38       : ReaderBase(
     39             strings::StrCat("FixedLengthRecordReader '", node_name, "'")),
     40         header_bytes_(header_bytes),
     41         record_bytes_(record_bytes),
     42         footer_bytes_(footer_bytes),
     43         hop_bytes_(hop_bytes == 0 ? record_bytes : hop_bytes),
     44         env_(env),
     45         record_number_(0),
     46         encoding_(encoding) {}
     47 
     48   // On success:
     49   // * buffered_inputstream_ != nullptr,
     50   // * buffered_inputstream_->Tell() == header_bytes_
     51   Status OnWorkStartedLocked() override {
     52     record_number_ = 0;
     53 
     54     lookahead_cache_.clear();
     55 
     56     TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file_));
     57     if (encoding_ == "ZLIB" || encoding_ == "GZIP") {
     58       const io::ZlibCompressionOptions zlib_options =
     59           encoding_ == "ZLIB" ? io::ZlibCompressionOptions::DEFAULT()
     60                               : io::ZlibCompressionOptions::GZIP();
     61       file_stream_.reset(new io::RandomAccessInputStream(file_.get()));
     62       buffered_inputstream_.reset(new io::ZlibInputStream(
     63           file_stream_.get(), static_cast<size_t>(kBufferSize),
     64           static_cast<size_t>(kBufferSize), zlib_options));
     65     } else {
     66       buffered_inputstream_.reset(
     67           new io::BufferedInputStream(file_.get(), kBufferSize));
     68     }
     69     // header_bytes_ is always skipped.
     70     TF_RETURN_IF_ERROR(buffered_inputstream_->SkipNBytes(header_bytes_));
     71 
     72     return Status::OK();
     73   }
     74 
     75   Status OnWorkFinishedLocked() override {
     76     buffered_inputstream_.reset(nullptr);
     77     return Status::OK();
     78   }
     79 
     80   Status ReadLocked(string* key, string* value, bool* produced,
     81                     bool* at_end) override {
     82     // We will always "hop" the hop_bytes_ except the first record
     83     // where record_number_ == 0
     84     if (record_number_ != 0) {
     85       if (hop_bytes_ <= lookahead_cache_.size()) {
     86         // If hop_bytes_ is smaller than the cached data we skip the
     87         // hop_bytes_ from the cache.
     88         lookahead_cache_ = lookahead_cache_.substr(hop_bytes_);
     89       } else {
     90         // If hop_bytes_ is larger than the cached data, we clean up
     91         // the cache, then skip hop_bytes_ - cache_size from the file
     92         // as the cache_size has been skipped through cache.
     93         int64 cache_size = lookahead_cache_.size();
     94         lookahead_cache_.clear();
     95         Status s = buffered_inputstream_->SkipNBytes(hop_bytes_ - cache_size);
     96         if (!s.ok()) {
     97           if (!errors::IsOutOfRange(s)) {
     98             return s;
     99           }
    100           *at_end = true;
    101           return Status::OK();
    102         }
    103       }
    104     }
    105 
    106     // Fill up lookahead_cache_ to record_bytes_ + footer_bytes_
    107     int bytes_to_read = record_bytes_ + footer_bytes_ - lookahead_cache_.size();
    108     Status s = buffered_inputstream_->ReadNBytes(bytes_to_read, value);
    109     if (!s.ok()) {
    110       value->clear();
    111       if (!errors::IsOutOfRange(s)) {
    112         return s;
    113       }
    114       *at_end = true;
    115       return Status::OK();
    116     }
    117     lookahead_cache_.append(*value, 0, bytes_to_read);
    118     value->clear();
    119 
    120     // Copy first record_bytes_ from cache to value
    121     *value = lookahead_cache_.substr(0, record_bytes_);
    122 
    123     *key = strings::StrCat(current_work(), ":", record_number_);
    124     *produced = true;
    125     ++record_number_;
    126 
    127     return Status::OK();
    128   }
    129 
    130   Status ResetLocked() override {
    131     record_number_ = 0;
    132     buffered_inputstream_.reset(nullptr);
    133     lookahead_cache_.clear();
    134     return ReaderBase::ResetLocked();
    135   }
    136 
    137   // TODO(josh11b): Implement serializing and restoring the state.
    138 
    139  private:
    140   enum { kBufferSize = 256 << 10 /* 256 kB */ };
    141   const int64 header_bytes_;
    142   const int64 record_bytes_;
    143   const int64 footer_bytes_;
    144   const int64 hop_bytes_;
    145   // The purpose of lookahead_cache_ is to allows "one-pass" processing
    146   // without revisit previous processed data of the stream. This is needed
    147   // because certain compression like zlib does not allow random access
    148   // or even obtain the uncompressed stream size before hand.
    149   // The max size of the lookahead_cache_ could be
    150   // record_bytes_ + footer_bytes_
    151   string lookahead_cache_;
    152   Env* const env_;
    153   int64 record_number_;
    154   string encoding_;
    155   // must outlive buffered_inputstream_
    156   std::unique_ptr<RandomAccessFile> file_;
    157   // must outlive buffered_inputstream_
    158   std::unique_ptr<io::RandomAccessInputStream> file_stream_;
    159   std::unique_ptr<io::InputStreamInterface> buffered_inputstream_;
    160 };
    161 
    162 class FixedLengthRecordReaderOp : public ReaderOpKernel {
    163  public:
    164   explicit FixedLengthRecordReaderOp(OpKernelConstruction* context)
    165       : ReaderOpKernel(context) {
    166     int64 header_bytes = -1, record_bytes = -1, footer_bytes = -1,
    167           hop_bytes = -1;
    168     OP_REQUIRES_OK(context, context->GetAttr("header_bytes", &header_bytes));
    169     OP_REQUIRES_OK(context, context->GetAttr("record_bytes", &record_bytes));
    170     OP_REQUIRES_OK(context, context->GetAttr("footer_bytes", &footer_bytes));
    171     OP_REQUIRES_OK(context, context->GetAttr("hop_bytes", &hop_bytes));
    172     OP_REQUIRES(context, header_bytes >= 0,
    173                 errors::InvalidArgument("header_bytes must be >= 0 not ",
    174                                         header_bytes));
    175     OP_REQUIRES(context, record_bytes >= 0,
    176                 errors::InvalidArgument("record_bytes must be >= 0 not ",
    177                                         record_bytes));
    178     OP_REQUIRES(context, footer_bytes >= 0,
    179                 errors::InvalidArgument("footer_bytes must be >= 0 not ",
    180                                         footer_bytes));
    181     OP_REQUIRES(
    182         context, hop_bytes >= 0,
    183         errors::InvalidArgument("hop_bytes must be >= 0 not ", hop_bytes));
    184     Env* env = context->env();
    185     string encoding;
    186     OP_REQUIRES_OK(context, context->GetAttr("encoding", &encoding));
    187     SetReaderFactory([this, header_bytes, record_bytes, footer_bytes, hop_bytes,
    188                       encoding, env]() {
    189       return new FixedLengthRecordReader(name(), header_bytes, record_bytes,
    190                                          footer_bytes, hop_bytes, encoding,
    191                                          env);
    192     });
    193   }
    194 };
    195 
    // Register FixedLengthRecordReaderOp as the CPU kernel for both op names,
    // "FixedLengthRecordReader" and "FixedLengthRecordReaderV2".
    196 REGISTER_KERNEL_BUILDER(Name("FixedLengthRecordReader").Device(DEVICE_CPU),
    197                         FixedLengthRecordReaderOp);
    198 REGISTER_KERNEL_BUILDER(Name("FixedLengthRecordReaderV2").Device(DEVICE_CPU),
    199                         FixedLengthRecordReaderOp);
    200 
    201 }  // namespace tensorflow
    202