Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/io_ops.cc.
     17 
     18 #include <memory>
     19 #include "tensorflow/core/framework/op_kernel.h"
     20 #include "tensorflow/core/framework/reader_base.h"
     21 #include "tensorflow/core/framework/reader_base.pb.h"
     22 #include "tensorflow/core/framework/reader_op_kernel.h"
     23 #include "tensorflow/core/framework/tensor_shape.h"
     24 #include "tensorflow/core/lib/core/errors.h"
     25 #include "tensorflow/core/lib/io/buffered_inputstream.h"
     26 #include "tensorflow/core/lib/io/path.h"
     27 #include "tensorflow/core/lib/io/random_inputstream.h"
     28 #include "tensorflow/core/lib/strings/str_util.h"
     29 #include "tensorflow/core/lib/strings/strcat.h"
     30 #include "tensorflow/core/platform/env.h"
     31 #include "tensorflow/core/platform/protobuf.h"
     32 
     33 namespace tensorflow {
     34 
     35 static Status ReadEntireFile(Env* env, const string& filename,
     36                              string* contents) {
     37   std::unique_ptr<RandomAccessFile> file;
     38   TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
     39   io::RandomAccessInputStream input_stream(file.get());
     40   io::BufferedInputStream in(&input_stream, 1 << 20);
     41   TF_RETURN_IF_ERROR(in.ReadAll(contents));
     42   return Status::OK();
     43 }
     44 
     45 class WholeFileReader : public ReaderBase {
     46  public:
     47   WholeFileReader(Env* env, const string& node_name)
     48       : ReaderBase(strings::StrCat("WholeFileReader '", node_name, "'")),
     49         env_(env) {}
     50 
     51   Status ReadLocked(string* key, string* value, bool* produced,
     52                     bool* at_end) override {
     53     *key = current_work();
     54     TF_RETURN_IF_ERROR(ReadEntireFile(env_, *key, value));
     55     *produced = true;
     56     *at_end = true;
     57     return Status::OK();
     58   }
     59 
     60   // Stores state in a ReaderBaseState proto, since WholeFileReader has
     61   // no additional state beyond ReaderBase.
     62   Status SerializeStateLocked(string* state) override {
     63     ReaderBaseState base_state;
     64     SaveBaseState(&base_state);
     65     base_state.SerializeToString(state);
     66     return Status::OK();
     67   }
     68 
     69   Status RestoreStateLocked(const string& state) override {
     70     ReaderBaseState base_state;
     71     if (!ParseProtoUnlimited(&base_state, state)) {
     72       return errors::InvalidArgument("Could not parse state for ", name(), ": ",
     73                                      str_util::CEscape(state));
     74     }
     75     TF_RETURN_IF_ERROR(RestoreBaseState(base_state));
     76     return Status::OK();
     77   }
     78 
     79  private:
     80   Env* env_;
     81 };
     82 
     83 class WholeFileReaderOp : public ReaderOpKernel {
     84  public:
     85   explicit WholeFileReaderOp(OpKernelConstruction* context)
     86       : ReaderOpKernel(context) {
     87     Env* env = context->env();
     88     SetReaderFactory(
     89         [this, env]() { return new WholeFileReader(env, name()); });
     90   }
     91 };
     92 
     93 REGISTER_KERNEL_BUILDER(Name("WholeFileReader").Device(DEVICE_CPU),
     94                         WholeFileReaderOp);
     95 REGISTER_KERNEL_BUILDER(Name("WholeFileReaderV2").Device(DEVICE_CPU),
     96                         WholeFileReaderOp);
     97 
     98 class ReadFileOp : public OpKernel {
     99  public:
    100   using OpKernel::OpKernel;
    101   void Compute(OpKernelContext* context) override {
    102     const Tensor* input;
    103     OP_REQUIRES_OK(context, context->input("filename", &input));
    104     OP_REQUIRES(context, TensorShapeUtils::IsScalar(input->shape()),
    105                 errors::InvalidArgument(
    106                     "Input filename tensor must be scalar, but had shape: ",
    107                     input->shape().DebugString()));
    108 
    109     Tensor* output = nullptr;
    110     OP_REQUIRES_OK(context, context->allocate_output("contents",
    111                                                      TensorShape({}), &output));
    112     OP_REQUIRES_OK(context,
    113                    ReadEntireFile(context->env(), input->scalar<string>()(),
    114                                   &output->scalar<string>()()));
    115   }
    116 };
    117 
    118 REGISTER_KERNEL_BUILDER(Name("ReadFile").Device(DEVICE_CPU), ReadFileOp);
    119 
    120 class WriteFileOp : public OpKernel {
    121  public:
    122   using OpKernel::OpKernel;
    123   void Compute(OpKernelContext* context) override {
    124     const Tensor* filename_input;
    125     const Tensor* contents_input;
    126     OP_REQUIRES_OK(context, context->input("filename", &filename_input));
    127     OP_REQUIRES_OK(context, context->input("contents", &contents_input));
    128     OP_REQUIRES(context, TensorShapeUtils::IsScalar(filename_input->shape()),
    129                 errors::InvalidArgument(
    130                     "Input filename tensor must be scalar, but had shape: ",
    131                     filename_input->shape().DebugString()));
    132     OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents_input->shape()),
    133                 errors::InvalidArgument(
    134                     "Contents tensor must be scalar, but had shape: ",
    135                     contents_input->shape().DebugString()));
    136     const string& filename = filename_input->scalar<string>()();
    137     const string dir = io::Dirname(filename).ToString();
    138     if (!context->env()->FileExists(dir).ok()) {
    139       OP_REQUIRES_OK(context, context->env()->RecursivelyCreateDir(dir));
    140     }
    141     OP_REQUIRES_OK(context,
    142                    WriteStringToFile(context->env(), filename,
    143                                      contents_input->scalar<string>()()));
    144   }
    145 };
    146 
    147 REGISTER_KERNEL_BUILDER(Name("WriteFile").Device(DEVICE_CPU), WriteFileOp);
    148 }  // namespace tensorflow
    149