1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/io_ops.cc. 17 18 #include <memory> 19 #include "tensorflow/core/framework/reader_base.h" 20 #include "tensorflow/core/framework/reader_op_kernel.h" 21 #include "tensorflow/core/lib/core/errors.h" 22 #include "tensorflow/core/lib/io/buffered_inputstream.h" 23 #include "tensorflow/core/lib/io/random_inputstream.h" 24 #include "tensorflow/core/lib/io/zlib_compression_options.h" 25 #include "tensorflow/core/lib/io/zlib_inputstream.h" 26 #include "tensorflow/core/lib/strings/strcat.h" 27 #include "tensorflow/core/platform/env.h" 28 29 namespace tensorflow { 30 31 // In the constructor hop_bytes_ is set to record_bytes_ if it was 0, 32 // so that we will always "hop" after each read (except first). 33 class FixedLengthRecordReader : public ReaderBase { 34 public: 35 FixedLengthRecordReader(const string& node_name, int64 header_bytes, 36 int64 record_bytes, int64 footer_bytes, 37 int64 hop_bytes, const string& encoding, Env* env) 38 : ReaderBase( 39 strings::StrCat("FixedLengthRecordReader '", node_name, "'")), 40 header_bytes_(header_bytes), 41 record_bytes_(record_bytes), 42 footer_bytes_(footer_bytes), 43 hop_bytes_(hop_bytes == 0 ? record_bytes : hop_bytes), 44 env_(env), 45 record_number_(0), 46 encoding_(encoding) {} 47 48 // On success: 49 // * buffered_inputstream_ != nullptr, 50 // * buffered_inputstream_->Tell() == header_bytes_ 51 Status OnWorkStartedLocked() override { 52 record_number_ = 0; 53 54 lookahead_cache_.clear(); 55 56 TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file_)); 57 if (encoding_ == "ZLIB" || encoding_ == "GZIP") { 58 const io::ZlibCompressionOptions zlib_options = 59 encoding_ == "ZLIB" ? io::ZlibCompressionOptions::DEFAULT() 60 : io::ZlibCompressionOptions::GZIP(); 61 file_stream_.reset(new io::RandomAccessInputStream(file_.get())); 62 buffered_inputstream_.reset(new io::ZlibInputStream( 63 file_stream_.get(), static_cast<size_t>(kBufferSize), 64 static_cast<size_t>(kBufferSize), zlib_options)); 65 } else { 66 buffered_inputstream_.reset( 67 new io::BufferedInputStream(file_.get(), kBufferSize)); 68 } 69 // header_bytes_ is always skipped. 70 TF_RETURN_IF_ERROR(buffered_inputstream_->SkipNBytes(header_bytes_)); 71 72 return Status::OK(); 73 } 74 75 Status OnWorkFinishedLocked() override { 76 buffered_inputstream_.reset(nullptr); 77 return Status::OK(); 78 } 79 80 Status ReadLocked(string* key, string* value, bool* produced, 81 bool* at_end) override { 82 // We will always "hop" the hop_bytes_ except the first record 83 // where record_number_ == 0 84 if (record_number_ != 0) { 85 if (hop_bytes_ <= lookahead_cache_.size()) { 86 // If hop_bytes_ is smaller than the cached data we skip the 87 // hop_bytes_ from the cache. 88 lookahead_cache_ = lookahead_cache_.substr(hop_bytes_); 89 } else { 90 // If hop_bytes_ is larger than the cached data, we clean up 91 // the cache, then skip hop_bytes_ - cache_size from the file 92 // as the cache_size has been skipped through cache. 93 int64 cache_size = lookahead_cache_.size(); 94 lookahead_cache_.clear(); 95 Status s = buffered_inputstream_->SkipNBytes(hop_bytes_ - cache_size); 96 if (!s.ok()) { 97 if (!errors::IsOutOfRange(s)) { 98 return s; 99 } 100 *at_end = true; 101 return Status::OK(); 102 } 103 } 104 } 105 106 // Fill up lookahead_cache_ to record_bytes_ + footer_bytes_ 107 int bytes_to_read = record_bytes_ + footer_bytes_ - lookahead_cache_.size(); 108 Status s = buffered_inputstream_->ReadNBytes(bytes_to_read, value); 109 if (!s.ok()) { 110 value->clear(); 111 if (!errors::IsOutOfRange(s)) { 112 return s; 113 } 114 *at_end = true; 115 return Status::OK(); 116 } 117 lookahead_cache_.append(*value, 0, bytes_to_read); 118 value->clear(); 119 120 // Copy first record_bytes_ from cache to value 121 *value = lookahead_cache_.substr(0, record_bytes_); 122 123 *key = strings::StrCat(current_work(), ":", record_number_); 124 *produced = true; 125 ++record_number_; 126 127 return Status::OK(); 128 } 129 130 Status ResetLocked() override { 131 record_number_ = 0; 132 buffered_inputstream_.reset(nullptr); 133 lookahead_cache_.clear(); 134 return ReaderBase::ResetLocked(); 135 } 136 137 // TODO(josh11b): Implement serializing and restoring the state. 138 139 private: 140 enum { kBufferSize = 256 << 10 /* 256 kB */ }; 141 const int64 header_bytes_; 142 const int64 record_bytes_; 143 const int64 footer_bytes_; 144 const int64 hop_bytes_; 145 // The purpose of lookahead_cache_ is to allows "one-pass" processing 146 // without revisit previous processed data of the stream. This is needed 147 // because certain compression like zlib does not allow random access 148 // or even obtain the uncompressed stream size before hand. 149 // The max size of the lookahead_cache_ could be 150 // record_bytes_ + footer_bytes_ 151 string lookahead_cache_; 152 Env* const env_; 153 int64 record_number_; 154 string encoding_; 155 // must outlive buffered_inputstream_ 156 std::unique_ptr<RandomAccessFile> file_; 157 // must outlive buffered_inputstream_ 158 std::unique_ptr<io::RandomAccessInputStream> file_stream_; 159 std::unique_ptr<io::InputStreamInterface> buffered_inputstream_; 160 }; 161 162 class FixedLengthRecordReaderOp : public ReaderOpKernel { 163 public: 164 explicit FixedLengthRecordReaderOp(OpKernelConstruction* context) 165 : ReaderOpKernel(context) { 166 int64 header_bytes = -1, record_bytes = -1, footer_bytes = -1, 167 hop_bytes = -1; 168 OP_REQUIRES_OK(context, context->GetAttr("header_bytes", &header_bytes)); 169 OP_REQUIRES_OK(context, context->GetAttr("record_bytes", &record_bytes)); 170 OP_REQUIRES_OK(context, context->GetAttr("footer_bytes", &footer_bytes)); 171 OP_REQUIRES_OK(context, context->GetAttr("hop_bytes", &hop_bytes)); 172 OP_REQUIRES(context, header_bytes >= 0, 173 errors::InvalidArgument("header_bytes must be >= 0 not ", 174 header_bytes)); 175 OP_REQUIRES(context, record_bytes >= 0, 176 errors::InvalidArgument("record_bytes must be >= 0 not ", 177 record_bytes)); 178 OP_REQUIRES(context, footer_bytes >= 0, 179 errors::InvalidArgument("footer_bytes must be >= 0 not ", 180 footer_bytes)); 181 OP_REQUIRES( 182 context, hop_bytes >= 0, 183 errors::InvalidArgument("hop_bytes must be >= 0 not ", hop_bytes)); 184 Env* env = context->env(); 185 string encoding; 186 OP_REQUIRES_OK(context, context->GetAttr("encoding", &encoding)); 187 SetReaderFactory([this, header_bytes, record_bytes, footer_bytes, hop_bytes, 188 encoding, env]() { 189 return new FixedLengthRecordReader(name(), header_bytes, record_bytes, 190 footer_bytes, hop_bytes, encoding, 191 env); 192 }); 193 } 194 }; 195 196 REGISTER_KERNEL_BUILDER(Name("FixedLengthRecordReader").Device(DEVICE_CPU), 197 FixedLengthRecordReaderOp); 198 REGISTER_KERNEL_BUILDER(Name("FixedLengthRecordReaderV2").Device(DEVICE_CPU), 199 FixedLengthRecordReaderOp); 200 201 } // namespace tensorflow 202