1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #define EIGEN_USE_THREADS 17 18 #include <deque> 19 #include <utility> 20 21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 22 #include "tensorflow/core/framework/resource_mgr.h" 23 #include "tensorflow/core/kernels/captured_function.h" 24 #include "tensorflow/core/lib/core/errors.h" 25 #include "tensorflow/core/platform/macros.h" 26 #include "tensorflow/core/platform/mutex.h" 27 #include "tensorflow/core/platform/types.h" 28 29 namespace tensorflow { 30 31 class CriticalSection : public ResourceBase { 32 public: 33 explicit CriticalSection() : is_locked_(false) {} 34 ~CriticalSection() override { 35 // Wait for all closures to finish running. 36 mutex_lock lock(mu_); 37 while (!closures_.empty()) { 38 queue_empty_cv_.wait(lock); 39 } 40 } 41 42 private: 43 friend class ExecuteInCriticalSectionOp; 44 45 void Acquire(std::function<void()> closure) { 46 std::function<void()> next; 47 { 48 mutex_lock ml(mu_); 49 if (is_locked_) { 50 closures_.push_back(std::move(closure)); 51 } else { 52 // This branch is the common case. Avoid the queue. 53 is_locked_ = true; 54 next = std::move(closure); 55 } 56 } 57 if (next) { 58 next(); 59 } 60 } 61 62 void Release() { 63 std::function<void()> next; 64 { 65 mutex_lock ml(mu_); 66 CHECK(is_locked_); 67 if (!closures_.empty()) { 68 // if queue is not empty, start the next entry off the queue. 69 std::swap(next, closures_.front()); 70 closures_.pop_front(); 71 } else { 72 is_locked_ = false; 73 queue_empty_cv_.notify_all(); 74 } 75 } 76 if (next) { 77 next(); 78 } 79 } 80 81 string DebugString() override { 82 tf_shared_lock ml(mu_); 83 return strings::StrCat("CriticalSection(locked: ", is_locked_, 84 " queue_size: ", closures_.size(), ")"); 85 } 86 87 private: 88 mutex mu_; 89 std::deque<std::function<void()>> closures_ GUARDED_BY(mu_); 90 bool is_locked_ GUARDED_BY(mu_); 91 condition_variable queue_empty_cv_ GUARDED_BY(mu_); 92 }; 93 94 class ExecuteInCriticalSectionOp : public AsyncOpKernel { 95 public: 96 explicit ExecuteInCriticalSectionOp(OpKernelConstruction* c) 97 : AsyncOpKernel(c) { 98 OP_REQUIRES_OK(c, c->GetAttr("f", &func_)); 99 } 100 101 public: 102 void ComputeAsync(OpKernelContext* c, DoneCallback done) override { 103 CriticalSection* critical_section = nullptr; 104 OP_REQUIRES_OK_ASYNC(c, 105 LookupOrCreateResource<CriticalSection>( 106 c, HandleFromInput(c, 0), &critical_section, 107 [this, c](CriticalSection** ptr) { 108 *ptr = new CriticalSection; 109 return Status::OK(); 110 }), 111 done); 112 // No need to Unref critical_section; the Closure below will take 113 // care of the Unref associated with this execution. 114 115 auto* execution = new Closure{std::move(done), c, critical_section, &func_}; 116 execution->Start(); 117 } 118 119 private: 120 class Closure { 121 public: 122 AsyncOpKernel::DoneCallback done_; 123 OpKernelContext* ctx_; 124 CriticalSection* cs_; 125 FunctionLibraryRuntime::Handle handle_; 126 FunctionLibraryRuntime::Options opts_; 127 std::vector<Tensor> arguments_t_; 128 std::vector<Tensor> output_t_; 129 NameAttrList* func_; 130 131 explicit Closure(AsyncOpKernel::DoneCallback done, OpKernelContext* ctx, 132 CriticalSection* critical_section, NameAttrList* func) 133 : done_(std::move(done)), 134 ctx_(ctx), 135 cs_(critical_section), 136 handle_(-1), 137 func_(func) {} 138 139 ~Closure(); 140 141 void Start() { 142 // Perform ExecuteFunction isnide a separate thread to avoid 143 // having lightweight Functions be inlined in this thread. 144 // That inlining would in turn inline DoneAndDelete inside the 145 // same thread. Since DoneAndDelete can call the next 146 // ExecuteFunction in the CriticalSection, this can cause a 147 // stack overflow. 148 cs_->Acquire( 149 [this]() { (*ctx_->runner())([this]() { ExecuteFunction(); }); }); 150 } 151 152 private: 153 void ExecuteFunction(); 154 void DoneAndDelete(const Status& status); 155 }; 156 157 NameAttrList func_; 158 }; 159 160 void ExecuteInCriticalSectionOp::Closure::ExecuteFunction() { 161 // Arguments to a Function are in the order: 162 // concat(<formal arguments>, <captured arguments>) 163 OpInputList arguments; 164 Status s = ctx_->input_list("arguments", &arguments); 165 if (!s.ok()) { 166 DoneAndDelete(s); 167 return; 168 } 169 170 arguments_t_.reserve(arguments.size()); 171 for (const Tensor& t : arguments) { 172 arguments_t_.push_back(t); 173 } 174 175 auto* function_library = ctx_->function_library(); 176 s = function_library->Instantiate(func_->name(), AttrSlice(&func_->attr()), 177 &handle_); 178 if (!s.ok()) { 179 DoneAndDelete(s); 180 return; 181 } 182 183 opts_.step_id = CapturedFunction::generate_step_id(); 184 auto* step_container = 185 new ScopedStepContainer(opts_.step_id, [this](const string& name) { 186 ctx_->resource_manager()->Cleanup(name).IgnoreError(); 187 }); 188 opts_.cancellation_manager = ctx_->cancellation_manager(); 189 opts_.step_container = step_container; 190 opts_.runner = ctx_->runner(); 191 192 function_library->Run(opts_, handle_, arguments_t_, &output_t_, 193 [this](const Status& s) { DoneAndDelete(s); }); 194 } 195 196 void ExecuteInCriticalSectionOp::Closure::DoneAndDelete(const Status& status) { 197 cs_->Release(); 198 199 if (!status.ok()) { 200 ctx_->SetStatus(status); 201 } else { 202 OpOutputList output; 203 const Status s = ctx_->output_list("outputs", &output); 204 if (!s.ok()) { 205 ctx_->SetStatus(s); 206 } else if (output_t_.size() != output.size()) { 207 ctx_->SetStatus(errors::Internal( 208 "Could not set all outputs. Expected output size is ", output.size(), 209 " but function set ", output_t_.size(), " output values.")); 210 } else { 211 for (int i = 0; i < output_t_.size(); ++i) { 212 output.set(i, output_t_[i]); 213 } 214 } 215 } 216 217 delete opts_.step_container; 218 opts_.step_container = nullptr; 219 done_(); 220 cs_->Unref(); 221 delete this; 222 } 223 224 ExecuteInCriticalSectionOp::Closure::~Closure() { 225 CHECK(!opts_.step_container) 226 << "Initialized closure destroyed without calling Done"; 227 } 228 229 REGISTER_KERNEL_BUILDER(Name("ExecuteInCriticalSection").Device(DEVICE_CPU), 230 ExecuteInCriticalSectionOp); 231 232 REGISTER_KERNEL_BUILDER(Name("CriticalSectionOp").Device(DEVICE_CPU), 233 ResourceHandleOp<CriticalSection>); 234 235 // TODO(ebrevdo): Re-enable once the cross-device function execution works. 236 #if GOOGLE_CUDA 237 REGISTER_KERNEL_BUILDER(Name("ExecuteInCriticalSection") 238 .Device(DEVICE_GPU) 239 .HostMemory("critical_section"), 240 ExecuteInCriticalSectionOp); 241 REGISTER_KERNEL_BUILDER( 242 Name("CriticalSectionOp").Device(DEVICE_GPU).HostMemory("resource"), 243 ResourceHandleOp<CriticalSection>); 244 #endif // GOOGLE_CUDA 245 246 } // namespace tensorflow 247