/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <deque>
#include <utility>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/kernels/captured_function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

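// Resource that serializes execution of closures submitted via Acquire():
// at most one closure runs at a time, and the rest wait in a FIFO queue
// until a Release() hands the lock to the next one.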
class CriticalSection : public ResourceBase {
 public:
  explicit CriticalSection() : is_locked_(false) {}
  ~CriticalSection() override {
    // Wait for all closures to finish running.
    mutex_lock lock(mu_);
    while (!closures_.empty()) {
      queue_empty_cv_.wait(lock);
    }
  }

 private:
  friend class ExecuteInCriticalSectionOp;

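  // Runs `closure` immediately if the critical section is currently
  // unlocked (the common case); otherwise enqueues it to be run by a
  // later Release().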
  void Acquire(std::function<void()> closure) {
    std::function<void()> next;
    {
      mutex_lock ml(mu_);
      if (is_locked_) {
        closures_.push_back(std::move(closure));
      } else {
        // This branch is the common case.  Avoid the queue.
        is_locked_ = true;
        next = std::move(closure);
      }
    }
    if (next) {
      next();
    }
  }

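  // Releases the lock: if closures are waiting, pops and runs the next
  // one (transferring the lock to it); otherwise marks the section
  // unlocked and notifies waiters (e.g. the destructor) that the queue
  // is empty.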
  void Release() {
    std::function<void()> next;
    {
      mutex_lock ml(mu_);
      CHECK(is_locked_);
      if (!closures_.empty()) {
        // If the queue is not empty, start the next entry off the queue.
        std::swap(next, closures_.front());
        closures_.pop_front();
      } else {
        is_locked_ = false;
        queue_empty_cv_.notify_all();
      }
    }
    if (next) {
      next();
    }
  }

  string DebugString() override {
    tf_shared_lock ml(mu_);
    return strings::StrCat("CriticalSection(locked: ", is_locked_,
                           " queue_size: ", closures_.size(), ")");
  }

 private:
  mutex mu_;
  std::deque<std::function<void()>> closures_ GUARDED_BY(mu_);
  bool is_locked_ GUARDED_BY(mu_);
  condition_variable queue_empty_cv_ GUARDED_BY(mu_);
};

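// AsyncOpKernel that looks up (or creates) the CriticalSection resource
// referenced by its first input and then runs the function attribute `f`
// while holding that critical section's lock.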
class ExecuteInCriticalSectionOp : public AsyncOpKernel {
 public:
  explicit ExecuteInCriticalSectionOp(OpKernelConstruction* c)
      : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("f", &func_));
  }

 public:
  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    CriticalSection* critical_section = nullptr;
    OP_REQUIRES_OK_ASYNC(c,
                         LookupOrCreateResource<CriticalSection>(
                             c, HandleFromInput(c, 0), &critical_section,
                             [this, c](CriticalSection** ptr) {
                               *ptr = new CriticalSection;
                               return Status::OK();
                             }),
                         done);
    // No need to Unref critical_section; the Closure below will take
    // care of the Unref associated with this execution.

    auto* execution = new Closure{std::move(done), c, critical_section, &func_};
    execution->Start();
  }

 private:
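  // Per-execution state: captures the done callback, the kernel context,
  // and the CriticalSection, and drives the lifecycle
  // Start() -> ExecuteFunction() -> DoneAndDelete().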
  class Closure {
   public:
    AsyncOpKernel::DoneCallback done_;
    OpKernelContext* ctx_;
    CriticalSection* cs_;
    FunctionLibraryRuntime::Handle handle_;
    FunctionLibraryRuntime::Options opts_;
    std::vector<Tensor> arguments_t_;
    std::vector<Tensor> output_t_;
    NameAttrList* func_;

    explicit Closure(AsyncOpKernel::DoneCallback done, OpKernelContext* ctx,
                     CriticalSection* critical_section, NameAttrList* func)
        : done_(std::move(done)),
          ctx_(ctx),
          cs_(critical_section),
          handle_(-1),
          func_(func) {}

    ~Closure();

    void Start() {
      // Perform ExecuteFunction inside a separate thread to avoid
      // having lightweight Functions be inlined in this thread.
      // That inlining would in turn inline DoneAndDelete inside the
      // same thread.  Since DoneAndDelete can call the next
      // ExecuteFunction in the CriticalSection, this can cause a
      // stack overflow.
      cs_->Acquire(
          [this]() { (*ctx_->runner())([this]() { ExecuteFunction(); }); });
    }

   private:
    void ExecuteFunction();
    void DoneAndDelete(const Status& status);
  };

  NameAttrList func_;
};

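// Instantiates the function `f` and runs it via the FunctionLibraryRuntime,
// passing this op's "arguments" inputs; DoneAndDelete is installed as the
// completion callback.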
void ExecuteInCriticalSectionOp::Closure::ExecuteFunction() {
  // Arguments to a Function are in the order:
  //   concat(<formal arguments>, <captured arguments>)
  OpInputList arguments;
  Status s = ctx_->input_list("arguments", &arguments);
  if (!s.ok()) {
    DoneAndDelete(s);
    return;
  }

  arguments_t_.reserve(arguments.size());
  for (const Tensor& t : arguments) {
    arguments_t_.push_back(t);
  }

  auto* function_library = ctx_->function_library();
  s = function_library->Instantiate(func_->name(), AttrSlice(&func_->attr()),
                                    &handle_);
  if (!s.ok()) {
    DoneAndDelete(s);
    return;
  }

  opts_.step_id = CapturedFunction::generate_step_id();
  auto* step_container =
      new ScopedStepContainer(opts_.step_id, [this](const string& name) {
        ctx_->resource_manager()->Cleanup(name).IgnoreError();
      });
  opts_.cancellation_manager = ctx_->cancellation_manager();
  opts_.step_container = step_container;
  opts_.runner = ctx_->runner();

  function_library->Run(opts_, handle_, arguments_t_, &output_t_,
                        [this](const Status& s) { DoneAndDelete(s); });
}

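// Completion callback for the function run: releases the critical section
// first (so the next queued closure can start), forwards the function's
// status and outputs to the kernel context, then calls done() and deletes
// this closure.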
void ExecuteInCriticalSectionOp::Closure::DoneAndDelete(const Status& status) {
  cs_->Release();

  if (!status.ok()) {
    ctx_->SetStatus(status);
  } else {
    OpOutputList output;
    const Status s = ctx_->output_list("outputs", &output);
    if (!s.ok()) {
      ctx_->SetStatus(s);
    } else if (output_t_.size() != output.size()) {
      ctx_->SetStatus(errors::Internal(
          "Could not set all outputs.  Expected output size is ", output.size(),
          " but function set ", output_t_.size(), " output values."));
    } else {
      for (int i = 0; i < output_t_.size(); ++i) {
        output.set(i, output_t_[i]);
      }
    }
  }

  delete opts_.step_container;
  opts_.step_container = nullptr;
  done_();
  cs_->Unref();
  delete this;
}

ExecuteInCriticalSectionOp::Closure::~Closure() {
  CHECK(!opts_.step_container)
      << "Initialized closure destroyed without calling Done";
}

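// Kernel registrations: "CriticalSectionOp" produces the resource handle for
// a CriticalSection (via ResourceHandleOp), and "ExecuteInCriticalSection"
// runs a function while holding it.  For the GPU registrations the resource
// handle is kept in host memory.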
REGISTER_KERNEL_BUILDER(Name("ExecuteInCriticalSection").Device(DEVICE_CPU),
                        ExecuteInCriticalSectionOp);

REGISTER_KERNEL_BUILDER(Name("CriticalSectionOp").Device(DEVICE_CPU),
                        ResourceHandleOp<CriticalSection>);

// TODO(ebrevdo): Re-enable once the cross-device function execution works.
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("ExecuteInCriticalSection")
                            .Device(DEVICE_GPU)
                            .HostMemory("critical_section"),
                        ExecuteInCriticalSectionOp);
REGISTER_KERNEL_BUILDER(
    Name("CriticalSectionOp").Device(DEVICE_GPU).HostMemory("resource"),
    ResourceHandleOp<CriticalSection>);
#endif  // GOOGLE_CUDA

}  // namespace tensorflow