Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/io_ops.cc.
     17 
     18 #include <memory>
     19 #include "tensorflow/core/framework/reader_base.h"
     20 #include "tensorflow/core/framework/reader_op_kernel.h"
     21 #include "tensorflow/core/lib/core/errors.h"
     22 #include "tensorflow/core/lib/io/inputbuffer.h"
     23 #include "tensorflow/core/lib/strings/strcat.h"
     24 #include "tensorflow/core/platform/env.h"
     25 
     26 namespace tensorflow {
     27 
     28 class TextLineReader : public ReaderBase {
     29  public:
     30   TextLineReader(const string& node_name, int skip_header_lines, Env* env)
     31       : ReaderBase(strings::StrCat("TextLineReader '", node_name, "'")),
     32         skip_header_lines_(skip_header_lines),
     33         env_(env),
     34         line_number_(0) {}
     35 
     36   Status OnWorkStartedLocked() override {
     37     line_number_ = 0;
     38     TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file_));
     39 
     40     input_buffer_.reset(new io::InputBuffer(file_.get(), kBufferSize));
     41     for (; line_number_ < skip_header_lines_; ++line_number_) {
     42       string line_contents;
     43       Status status = input_buffer_->ReadLine(&line_contents);
     44       if (errors::IsOutOfRange(status)) {
     45         // We ignore an end of file error when skipping header lines.
     46         // We will end up skipping this file.
     47         return Status::OK();
     48       }
     49       TF_RETURN_IF_ERROR(status);
     50     }
     51     return Status::OK();
     52   }
     53 
     54   Status OnWorkFinishedLocked() override {
     55     input_buffer_.reset(nullptr);
     56     return Status::OK();
     57   }
     58 
     59   Status ReadLocked(string* key, string* value, bool* produced,
     60                     bool* at_end) override {
     61     Status status = input_buffer_->ReadLine(value);
     62     ++line_number_;
     63     if (status.ok()) {
     64       *key = strings::StrCat(current_work(), ":", line_number_);
     65       *produced = true;
     66       return status;
     67     }
     68     if (errors::IsOutOfRange(status)) {  // End of file, advance to the next.
     69       *at_end = true;
     70       return Status::OK();
     71     } else {  // Some other reading error
     72       return status;
     73     }
     74   }
     75 
     76   Status ResetLocked() override {
     77     line_number_ = 0;
     78     input_buffer_.reset(nullptr);
     79     return ReaderBase::ResetLocked();
     80   }
     81 
     82   // TODO(josh11b): Implement serializing and restoring the state.  Need
     83   // to create TextLineReaderState proto to store ReaderBaseState,
     84   // line_number_, and input_buffer_->Tell().
     85 
     86  private:
     87   enum { kBufferSize = 256 << 10 /* 256 kB */ };
     88   const int skip_header_lines_;
     89   Env* const env_;
     90   int64 line_number_;
     91   std::unique_ptr<RandomAccessFile> file_;  // must outlive input_buffer_
     92   std::unique_ptr<io::InputBuffer> input_buffer_;
     93 };
     94 
     95 class TextLineReaderOp : public ReaderOpKernel {
     96  public:
     97   explicit TextLineReaderOp(OpKernelConstruction* context)
     98       : ReaderOpKernel(context) {
     99     int skip_header_lines = -1;
    100     OP_REQUIRES_OK(context,
    101                    context->GetAttr("skip_header_lines", &skip_header_lines));
    102     OP_REQUIRES(context, skip_header_lines >= 0,
    103                 errors::InvalidArgument("skip_header_lines must be >= 0 not ",
    104                                         skip_header_lines));
    105     Env* env = context->env();
    106     SetReaderFactory([this, skip_header_lines, env]() {
    107       return new TextLineReader(name(), skip_header_lines, env);
    108     });
    109   }
    110 };
    111 
    112 REGISTER_KERNEL_BUILDER(Name("TextLineReader").Device(DEVICE_CPU),
    113                         TextLineReaderOp);
    114 REGISTER_KERNEL_BUILDER(Name("TextLineReaderV2").Device(DEVICE_CPU),
    115                         TextLineReaderOp);
    116 
    117 }  // namespace tensorflow
    118