1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/io_ops.cc. 17 18 #include <memory> 19 #include "tensorflow/core/framework/reader_base.h" 20 #include "tensorflow/core/framework/reader_op_kernel.h" 21 #include "tensorflow/core/lib/core/errors.h" 22 #include "tensorflow/core/lib/io/inputbuffer.h" 23 #include "tensorflow/core/lib/strings/strcat.h" 24 #include "tensorflow/core/platform/env.h" 25 26 namespace tensorflow { 27 28 class TextLineReader : public ReaderBase { 29 public: 30 TextLineReader(const string& node_name, int skip_header_lines, Env* env) 31 : ReaderBase(strings::StrCat("TextLineReader '", node_name, "'")), 32 skip_header_lines_(skip_header_lines), 33 env_(env), 34 line_number_(0) {} 35 36 Status OnWorkStartedLocked() override { 37 line_number_ = 0; 38 TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file_)); 39 40 input_buffer_.reset(new io::InputBuffer(file_.get(), kBufferSize)); 41 for (; line_number_ < skip_header_lines_; ++line_number_) { 42 string line_contents; 43 Status status = input_buffer_->ReadLine(&line_contents); 44 if (errors::IsOutOfRange(status)) { 45 // We ignore an end of file error when skipping header lines. 46 // We will end up skipping this file. 47 return Status::OK(); 48 } 49 TF_RETURN_IF_ERROR(status); 50 } 51 return Status::OK(); 52 } 53 54 Status OnWorkFinishedLocked() override { 55 input_buffer_.reset(nullptr); 56 return Status::OK(); 57 } 58 59 Status ReadLocked(string* key, string* value, bool* produced, 60 bool* at_end) override { 61 Status status = input_buffer_->ReadLine(value); 62 ++line_number_; 63 if (status.ok()) { 64 *key = strings::StrCat(current_work(), ":", line_number_); 65 *produced = true; 66 return status; 67 } 68 if (errors::IsOutOfRange(status)) { // End of file, advance to the next. 69 *at_end = true; 70 return Status::OK(); 71 } else { // Some other reading error 72 return status; 73 } 74 } 75 76 Status ResetLocked() override { 77 line_number_ = 0; 78 input_buffer_.reset(nullptr); 79 return ReaderBase::ResetLocked(); 80 } 81 82 // TODO(josh11b): Implement serializing and restoring the state. Need 83 // to create TextLineReaderState proto to store ReaderBaseState, 84 // line_number_, and input_buffer_->Tell(). 85 86 private: 87 enum { kBufferSize = 256 << 10 /* 256 kB */ }; 88 const int skip_header_lines_; 89 Env* const env_; 90 int64 line_number_; 91 std::unique_ptr<RandomAccessFile> file_; // must outlive input_buffer_ 92 std::unique_ptr<io::InputBuffer> input_buffer_; 93 }; 94 95 class TextLineReaderOp : public ReaderOpKernel { 96 public: 97 explicit TextLineReaderOp(OpKernelConstruction* context) 98 : ReaderOpKernel(context) { 99 int skip_header_lines = -1; 100 OP_REQUIRES_OK(context, 101 context->GetAttr("skip_header_lines", &skip_header_lines)); 102 OP_REQUIRES(context, skip_header_lines >= 0, 103 errors::InvalidArgument("skip_header_lines must be >= 0 not ", 104 skip_header_lines)); 105 Env* env = context->env(); 106 SetReaderFactory([this, skip_header_lines, env]() { 107 return new TextLineReader(name(), skip_header_lines, env); 108 }); 109 } 110 }; 111 112 REGISTER_KERNEL_BUILDER(Name("TextLineReader").Device(DEVICE_CPU), 113 TextLineReaderOp); 114 REGISTER_KERNEL_BUILDER(Name("TextLineReaderV2").Device(DEVICE_CPU), 115 TextLineReaderOp); 116 117 } // namespace tensorflow 118