/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstddef>
#include <deque>
#include <mutex>
#include <numeric>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace {

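// A thread-safe FIFO of tensor tuples backing the staging ops below.
// `capacity` bounds the number of staged tuples and `memory_limit` bounds
// their combined size in bytes; a value of 0 disables the corresponding
// bound. Put() blocks while inserting would violate a configured bound,
// and Get() blocks while the buffer is empty.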
class Buffer : public ResourceBase {
 public:
  // public types
  using Tuple = std::vector<Tensor>;

 private:
  // private variables
  std::size_t capacity_;
  std::size_t memory_limit_;
  std::size_t current_bytes_;
  std::mutex mu_;
  std::condition_variable non_empty_cond_var_;
  std::condition_variable full_cond_var_;
  std::deque<Tuple> buf_;

 private:
  // private methods

  // If the buffer is bounded by capacity or by a memory limit, notify
  // waiting inserters that space may now be available.
  void notify_inserters_if_bounded(std::unique_lock<std::mutex>* lock) {
    if (IsBounded()) {
      lock->unlock();
      // Notify all inserters. The removal of an element
      // may free enough memory for several inserters
      // to add new elements.
      full_cond_var_.notify_all();
    }
  }

  // Is there a limit on the number of elements or on the total
  // memory configured for this buffer?
  bool IsBounded() const { return capacity_ > 0 || memory_limit_ > 0; }

  bool IsCapacityFull() const { return buf_.size() >= capacity_; }

  bool WouldExceedMemoryLimit(std::size_t bytes) const {
    return bytes + current_bytes_ > memory_limit_;
  }

  // Total number of bytes occupied by the tensors in the tuple.
  std::size_t GetTupleBytes(const Tuple& tuple) const {
    // Accumulate into std::size_t rather than int to avoid truncation
    // for very large tuples.
    return std::accumulate(tuple.begin(), tuple.end(),
                           static_cast<std::size_t>(0),
                           [](const std::size_t& lhs, const Tensor& rhs) {
                             return lhs + rhs.TotalBytes();
                           });
  }

 public:
  // public methods
  explicit Buffer(std::size_t capacity, std::size_t memory_limit)
      : capacity_(capacity), memory_limit_(memory_limit), current_bytes_(0) {}

  // the Buffer takes ownership of the Tuple
  Status Put(Tuple* tuple) {
    std::unique_lock<std::mutex> lock(mu_);

    std::size_t tuple_bytes = GetTupleBytes(*tuple);

    // Sanity check so that we don't block forever below
    if (memory_limit_ > 0 && tuple_bytes > memory_limit_) {
      return errors::ResourceExhausted(
          "Attempted to insert tensors with combined size of '", tuple_bytes,
          "' bytes into Staging Area with a memory limit of '", memory_limit_,
          "'.");
    }

    // If the buffer is bounded, wait until elements have been removed
    if (IsBounded()) {
      full_cond_var_.wait(lock, [tuple_bytes, this]() {
        // If there's a memory limit, check if there's space for insertion
        bool memory_limit_valid =
            memory_limit_ > 0 ? !WouldExceedMemoryLimit(tuple_bytes) : true;
        // If we're configured with a capacity, check if there's space for
        // insertion
        bool capacity_valid = capacity_ > 0 ? !IsCapacityFull() : true;

        // Stop waiting once both conditions are satisfied
        return capacity_valid && memory_limit_valid;
      });
    }

    // Update bytes in the Staging Area
    current_bytes_ += tuple_bytes;

    // Store tuple
    buf_.push_back(std::move(*tuple));

    lock.unlock();
    // Notify all removers. Removers may be peeking at a specific
    // element or waiting for the element at the front of the deque.
    // As we don't know which one to wake, we wake them all.
    non_empty_cond_var_.notify_all();

    return Status::OK();
  }

  // Get tuple at front of the buffer
  void Get(Tuple* tuple) {  // TODO(zhifengc): Support cancellation.
    std::unique_lock<std::mutex> lock(mu_);

    // Wait for data if the buffer is empty
    non_empty_cond_var_.wait(lock, [this]() { return !buf_.empty(); });

    // Move data into the output tuple
    *tuple = std::move(buf_.front());
    buf_.pop_front();

    // Update bytes in the Staging Area
    current_bytes_ -= GetTupleBytes(*tuple);

    notify_inserters_if_bounded(&lock);
  }

  // Return a copy of the tuple at the given index without removing it
  Status Peek(std::size_t index, Tuple* tuple) {
    std::unique_lock<std::mutex> lock(mu_);

    // Wait until the requested index is available
    non_empty_cond_var_.wait(
        lock, [index, this]() { return index < this->buf_.size(); });

    // Place tensors in the output tuple
    for (const auto& tensor : buf_[index]) {
      tuple->push_back(tensor);
    }

    return Status::OK();
  }

  // Buffer size
  size_t Size() {
    std::unique_lock<std::mutex> lock(mu_);
    return buf_.size();
  }

  void Clear() {
    std::unique_lock<std::mutex> lock(mu_);
    buf_.clear();
    current_bytes_ = 0;

    notify_inserters_if_bounded(&lock);
  }

  string DebugString() override {
    std::unique_lock<std::mutex> lock(mu_);
    return strings::StrCat("Staging size: ", buf_.size());
  }
};
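
// A minimal usage sketch of the Buffer protocol (illustrative only; the
// tensor contents and the bounds shown here are hypothetical, not taken
// from this file):
//
//   Buffer buffer(/*capacity=*/2, /*memory_limit=*/0);  // at most 2 tuples
//   Buffer::Tuple in = {Tensor(DT_FLOAT, TensorShape({8}))};
//   TF_CHECK_OK(buffer.Put(&in));  // blocks while the buffer is full
//   Buffer::Tuple out;
//   buffer.Get(&out);              // blocks until a tuple is available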

// Looks up the Buffer resource associated with this node's container and
// shared name, creating it on first use from the node's `capacity` and
// `memory_limit` attributes. On success the caller holds a reference to
// *buf and must Unref() it when done.
Status GetBuffer(OpKernelContext* ctx, const NodeDef& ndef, Buffer** buf) {
  auto rm = ctx->resource_manager();
  ContainerInfo cinfo;

  // Lambda for creating the Staging Area
  auto create_fn = [&ndef](Buffer** ret) -> Status {
    int64 capacity;
    int64 memory_limit;
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "capacity", &capacity));
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "memory_limit", &memory_limit));
    *ret = new Buffer(capacity, memory_limit);
    return Status::OK();
  };

  TF_RETURN_IF_ERROR(cinfo.Init(rm, ndef, true /* use name() */));
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<Buffer>(cinfo.container(), cinfo.name(),
                                                buf, create_fn));
  return Status::OK();
}

}  // namespace

// Inserts the op's inputs into the staging buffer as a single tuple,
// blocking if the buffer is bounded and currently full.
class StageOp : public OpKernel {
 public:
  explicit StageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);
    Buffer::Tuple tuple;
    tuple.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      tuple.push_back(ctx->input(i));
    }
    OP_REQUIRES_OK(ctx, buf->Put(&tuple));
  }
};

REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_CPU), StageOp);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_GPU), StageOp);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_SYCL), StageOp);
#endif  // TENSORFLOW_USE_SYCL
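
// Stage and Unstage behave as a FIFO producer/consumer pair: Unstage
// returns tuples in the order they were staged. A typical pattern (e.g.
// via the Python StagingArea wrapper) is to stage the next input batch
// while the current one is being consumed, overlapping data transfer
// with computation.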

// Removes the tuple at the front of the staging buffer and outputs its
// tensors, blocking until a tuple is available.
class UnstageOp : public OpKernel {
 public:
  explicit UnstageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error. As such, cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);
    Buffer::Tuple tuple;

    buf->Get(&tuple);

    OP_REQUIRES(
        ctx, tuple.size() == static_cast<size_t>(ctx->num_outputs()),
        errors::InvalidArgument("Mismatch stage/unstage: ", tuple.size(),
                                " vs. ", ctx->num_outputs()));

    for (size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_CPU), UnstageOp);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_GPU), UnstageOp);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_SYCL), UnstageOp);
#endif  // TENSORFLOW_USE_SYCL

// Outputs the tensors of the tuple at the requested index without
// removing it from the staging buffer, blocking until that index exists.
class StagePeekOp : public OpKernel {
 public:
  explicit StagePeekOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error. As such, cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);
    Buffer::Tuple tuple;

    std::size_t index = ctx->input(0).scalar<int>()();

    OP_REQUIRES_OK(ctx, buf->Peek(index, &tuple));

    OP_REQUIRES(
        ctx, tuple.size() == static_cast<size_t>(ctx->num_outputs()),
        errors::InvalidArgument("Mismatch stage/unstage: ", tuple.size(),
                                " vs. ", ctx->num_outputs()));

    for (size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

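// On GPU (and SYCL) placements the `index` input is pinned to host memory,
// since Compute() reads it directly on the CPU via scalar<int>().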
REGISTER_KERNEL_BUILDER(Name("StagePeek").Device(DEVICE_CPU), StagePeekOp);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
    Name("StagePeek").HostMemory("index").Device(DEVICE_GPU), StagePeekOp);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
    Name("StagePeek").HostMemory("index").Device(DEVICE_SYCL), StagePeekOp);
#endif  // TENSORFLOW_USE_SYCL

// Outputs the number of tuples currently in the staging buffer as a
// scalar int32.
class StageSizeOp : public OpKernel {
 public:
  explicit StageSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error. As such, cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);

    // Allocate the scalar output tensor
    Tensor* size = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size));

    // Set it to the current number of staged tuples
    size->scalar<int32>().setConstant(buf->Size());
  }
};

REGISTER_KERNEL_BUILDER(Name("StageSize").Device(DEVICE_CPU), StageSizeOp);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("StageSize").HostMemory("size").Device(DEVICE_GPU),
                        StageSizeOp);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
    Name("StageSize").HostMemory("size").Device(DEVICE_SYCL), StageSizeOp);
#endif  // TENSORFLOW_USE_SYCL

// Removes all tuples from the staging buffer and resets its byte count.
class StageClearOp : public OpKernel {
 public:
  explicit StageClearOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error. As such, cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);

    buf->Clear();
  }
};

REGISTER_KERNEL_BUILDER(Name("StageClear").Device(DEVICE_CPU), StageClearOp);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("StageClear").Device(DEVICE_GPU), StageClearOp);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("StageClear").Device(DEVICE_SYCL), StageClearOp);
#endif  // TENSORFLOW_USE_SYCL

}  // namespace tensorflow