Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/io_ops.cc.
     17 
     18 #include "tensorflow/core/framework/op_kernel.h"
     19 #include "tensorflow/core/framework/queue_interface.h"
     20 #include "tensorflow/core/framework/reader_interface.h"
     21 #include "tensorflow/core/framework/tensor_shape.h"
     22 #include "tensorflow/core/kernels/ops_util.h"
     23 #include "tensorflow/core/lib/core/threadpool.h"
     24 #include "tensorflow/core/lib/strings/strcat.h"
     25 
     26 namespace tensorflow {
     27 
     28 class ReaderVerbSyncOpKernel : public OpKernel {
     29  public:
     30   using OpKernel::OpKernel;
     31 
     32   void Compute(OpKernelContext* context) override {
     33     ReaderInterface* reader;
     34     OP_REQUIRES_OK(context,
     35                    GetResourceFromContext(context, "reader_handle", &reader));
     36     ComputeWithReader(context, reader);
     37     reader->Unref();
     38   }
     39 
     40  protected:
     41   virtual void ComputeWithReader(OpKernelContext* context,
     42                                  ReaderInterface* reader) = 0;
     43 };
     44 
     45 class ReaderVerbAsyncOpKernel : public AsyncOpKernel {
     46  public:
     47   using AsyncOpKernel::AsyncOpKernel;
     48 
     49   explicit ReaderVerbAsyncOpKernel(OpKernelConstruction* context)
     50       : AsyncOpKernel(context),
     51         thread_pool_(new thread::ThreadPool(
     52             context->env(), ThreadOptions(),
     53             strings::StrCat("reader_thread_", SanitizeThreadSuffix(name())),
     54             1 /* num_threads */, false /* low_latency_hint */)) {}
     55 
     56   void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
     57     ReaderInterface* reader;
     58     OP_REQUIRES_OK_ASYNC(
     59         context, GetResourceFromContext(context, "reader_handle", &reader),
     60         done);
     61     thread_pool_->Schedule([this, context, reader, done]() {
     62       ComputeWithReader(context, reader);
     63       reader->Unref();
     64       done();
     65     });
     66   }
     67 
     68  protected:
     69   virtual void ComputeWithReader(OpKernelContext* context,
     70                                  ReaderInterface* reader) = 0;
     71 
     72  private:
     73   std::unique_ptr<thread::ThreadPool> thread_pool_;
     74 };
     75 
     76 class ReaderReadOp : public ReaderVerbAsyncOpKernel {
     77  public:
     78   using ReaderVerbAsyncOpKernel::ReaderVerbAsyncOpKernel;
     79 
     80   void ComputeWithReader(OpKernelContext* context,
     81                          ReaderInterface* reader) override {
     82     QueueInterface* queue;
     83     OP_REQUIRES_OK(context,
     84                    GetResourceFromContext(context, "queue_handle", &queue));
     85     core::ScopedUnref unref_me(queue);
     86     Tensor* key = nullptr;
     87     OP_REQUIRES_OK(context,
     88                    context->allocate_output("key", TensorShape({}), &key));
     89     Tensor* value = nullptr;
     90     OP_REQUIRES_OK(context,
     91                    context->allocate_output("value", TensorShape({}), &value));
     92 
     93     auto key_scalar = key->scalar<string>();
     94     auto value_scalar = value->scalar<string>();
     95     reader->Read(queue, &key_scalar(), &value_scalar(), context);
     96   }
     97 };
     98 
     99 REGISTER_KERNEL_BUILDER(Name("ReaderRead").Device(DEVICE_CPU), ReaderReadOp);
    100 REGISTER_KERNEL_BUILDER(Name("ReaderReadV2").Device(DEVICE_CPU), ReaderReadOp);
    101 
    102 class ReaderReadUpToOp : public ReaderVerbAsyncOpKernel {
    103  public:
    104   using ReaderVerbAsyncOpKernel::ReaderVerbAsyncOpKernel;
    105 
    106   void ComputeWithReader(OpKernelContext* context,
    107                          ReaderInterface* reader) override {
    108     QueueInterface* queue;
    109 
    110     const Tensor* num_records_tensor;
    111     OP_REQUIRES_OK(context, context->input("num_records", &num_records_tensor));
    112     int64 num_records = num_records_tensor->scalar<int64>()();
    113 
    114     OP_REQUIRES_OK(context,
    115                    GetResourceFromContext(context, "queue_handle", &queue));
    116     core::ScopedUnref unref_me(queue);
    117 
    118     std::vector<string> keys_vec;
    119     keys_vec.reserve(num_records);
    120     std::vector<string> values_vec;
    121     values_vec.reserve(num_records);
    122 
    123     int64 num_actually_read =
    124         reader->ReadUpTo(num_records, queue, &keys_vec, &values_vec, context);
    125 
    126     OP_REQUIRES(context, num_actually_read == keys_vec.size(),
    127                 errors::InvalidArgument("num_actually_read != len(keys_vec"));
    128 
    129     OP_REQUIRES(context, num_actually_read == values_vec.size(),
    130                 errors::InvalidArgument("num_actually_read != len(values_vec"));
    131 
    132     Tensor* keys = nullptr;
    133     OP_REQUIRES_OK(context,
    134                    context->allocate_output(
    135                        "keys", TensorShape({num_actually_read}), &keys));
    136 
    137     Tensor* values = nullptr;
    138     OP_REQUIRES_OK(context,
    139                    context->allocate_output(
    140                        "values", TensorShape({num_actually_read}), &values));
    141 
    142     auto keys_t = keys->vec<string>();
    143     auto values_t = values->vec<string>();
    144     for (int i = 0; i < num_actually_read; ++i) {
    145       keys_t(i) = std::move(keys_vec[i]);
    146       values_t(i) = std::move(values_vec[i]);
    147     }
    148   }
    149 };
    150 
    151 REGISTER_KERNEL_BUILDER(Name("ReaderReadUpTo").Device(DEVICE_CPU),
    152                         ReaderReadUpToOp);
    153 REGISTER_KERNEL_BUILDER(Name("ReaderReadUpToV2").Device(DEVICE_CPU),
    154                         ReaderReadUpToOp);
    155 
    156 class ReaderNumRecordsProducedOp : public ReaderVerbSyncOpKernel {
    157  public:
    158   using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel;
    159 
    160   void ComputeWithReader(OpKernelContext* context,
    161                          ReaderInterface* reader) override {
    162     Tensor* output = nullptr;
    163     OP_REQUIRES_OK(context, context->allocate_output("records_produced",
    164                                                      TensorShape({}), &output));
    165     output->scalar<int64>()() = reader->NumRecordsProduced();
    166   }
    167 };
    168 
    169 REGISTER_KERNEL_BUILDER(Name("ReaderNumRecordsProduced").Device(DEVICE_CPU),
    170                         ReaderNumRecordsProducedOp);
    171 REGISTER_KERNEL_BUILDER(Name("ReaderNumRecordsProducedV2").Device(DEVICE_CPU),
    172                         ReaderNumRecordsProducedOp);
    173 
    174 class ReaderNumWorkUnitsCompletedOp : public ReaderVerbSyncOpKernel {
    175  public:
    176   using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel;
    177 
    178   void ComputeWithReader(OpKernelContext* context,
    179                          ReaderInterface* reader) override {
    180     Tensor* output = nullptr;
    181     OP_REQUIRES_OK(context, context->allocate_output("units_completed",
    182                                                      TensorShape({}), &output));
    183     output->scalar<int64>()() = reader->NumWorkUnitsCompleted();
    184   }
    185 };
    186 
    187 REGISTER_KERNEL_BUILDER(Name("ReaderNumWorkUnitsCompleted").Device(DEVICE_CPU),
    188                         ReaderNumWorkUnitsCompletedOp);
    189 REGISTER_KERNEL_BUILDER(
    190     Name("ReaderNumWorkUnitsCompletedV2").Device(DEVICE_CPU),
    191     ReaderNumWorkUnitsCompletedOp);
    192 
    193 class ReaderSerializeStateOp : public ReaderVerbSyncOpKernel {
    194  public:
    195   using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel;
    196 
    197   void ComputeWithReader(OpKernelContext* context,
    198                          ReaderInterface* reader) override {
    199     Tensor* output = nullptr;
    200     OP_REQUIRES_OK(context,
    201                    context->allocate_output("state", TensorShape({}), &output));
    202     OP_REQUIRES_OK(context,
    203                    reader->SerializeState(&output->scalar<string>()()));
    204   }
    205 };
    206 
    207 REGISTER_KERNEL_BUILDER(Name("ReaderSerializeState").Device(DEVICE_CPU),
    208                         ReaderSerializeStateOp);
    209 REGISTER_KERNEL_BUILDER(Name("ReaderSerializeStateV2").Device(DEVICE_CPU),
    210                         ReaderSerializeStateOp);
    211 
    212 class ReaderRestoreStateOp : public ReaderVerbSyncOpKernel {
    213  public:
    214   using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel;
    215 
    216   void ComputeWithReader(OpKernelContext* context,
    217                          ReaderInterface* reader) override {
    218     const Tensor* tensor;
    219     OP_REQUIRES_OK(context, context->input("state", &tensor));
    220     OP_REQUIRES(
    221         context, TensorShapeUtils::IsScalar(tensor->shape()),
    222         errors::InvalidArgument("Reader state must be scalar, but had shape: ",
    223                                 tensor->shape().DebugString()));
    224     OP_REQUIRES_OK(context, reader->RestoreState(tensor->scalar<string>()()));
    225   }
    226 };
    227 
    228 REGISTER_KERNEL_BUILDER(Name("ReaderRestoreState").Device(DEVICE_CPU),
    229                         ReaderRestoreStateOp);
    230 REGISTER_KERNEL_BUILDER(Name("ReaderRestoreStateV2").Device(DEVICE_CPU),
    231                         ReaderRestoreStateOp);
    232 
    233 class ReaderResetOp : public ReaderVerbSyncOpKernel {
    234  public:
    235   using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel;
    236 
    237   void ComputeWithReader(OpKernelContext* context,
    238                          ReaderInterface* reader) override {
    239     OP_REQUIRES_OK(context, reader->Reset());
    240   }
    241 };
    242 
    243 REGISTER_KERNEL_BUILDER(Name("ReaderReset").Device(DEVICE_CPU), ReaderResetOp);
    244 REGISTER_KERNEL_BUILDER(Name("ReaderResetV2").Device(DEVICE_CPU),
    245                         ReaderResetOp);
    246 
    247 }  // namespace tensorflow
    248