/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef FunctionLibraryRuntime::Handle FHandle;
typedef std::vector<Tensor> TensorVec;

namespace {

// Helper to instantiate function "func" in the library "lib".
// On success, "*handle" names the instantiated function and can be passed
// to FunctionLibraryRuntime::Run. The handle is only valid for "lib".
Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func,
                   FunctionLibraryRuntime::Handle* handle) {
  return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle);
}

// If "t" is a scalar of a supported type, returns t != 0 in "*v".
42 Status ToBool(gtl::ArraySlice<Tensor> t, bool* v) { 43 if (t.size() != 1) { 44 return errors::InvalidArgument( 45 "Expected a single scalar which can be converted to a boolean, got ", 46 t.size(), " tensors."); 47 } 48 if (TensorShapeUtils::IsScalar(t[0].shape())) { 49 switch (t[0].dtype()) { 50 #define CASE(T) \ 51 case DataTypeToEnum<T>::value: \ 52 *v = t[0].scalar<T>()() != 0; \ 53 break; 54 55 CASE(float); 56 CASE(double); 57 CASE(int32); 58 CASE(uint8); 59 CASE(int16); 60 CASE(int8); 61 CASE(int64); 62 #undef CASE 63 case DT_BOOL: 64 *v = t[0].scalar<bool>()(); 65 break; 66 case DT_STRING: 67 *v = !t[0].scalar<string>()().empty(); 68 break; 69 default: 70 return errors::InvalidArgument(DataTypeString(t[0].dtype()), 71 " cannot be converted to a boolean"); 72 } 73 } else { 74 *v = t[0].NumElements() > 0; 75 } 76 return Status::OK(); 77 } 78 79 // Sets "rets" to be the output of "ctx". Validates rets' types based 80 // on "kernel". 81 Status SetOutputs(const OpKernel* kernel, OpKernelContext* ctx, 82 gtl::ArraySlice<Tensor> rets) { 83 if (rets.size() != ctx->num_outputs()) { 84 return errors::Internal("Expect to produce ", ctx->num_outputs(), 85 " tensors, but only get ", rets.size()); 86 } 87 for (int i = 0; i < rets.size(); ++i) { 88 if (rets[i].dtype() != kernel->output_type(i)) { 89 return errors::Internal("Expect ", i, "-th output is of type ", 90 DataTypeString(kernel->output_type(i)), 91 " but get ", DataTypeString(rets[i].dtype())); 92 } 93 ctx->set_output(i, rets[i]); 94 } 95 return Status::OK(); 96 } 97 98 void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts, 99 bool always_collect_stats) { 100 opts->step_id = ctx->step_id(); 101 opts->rendezvous = ctx->rendezvous(); 102 opts->cancellation_manager = ctx->cancellation_manager(); 103 if (always_collect_stats) { 104 opts->stats_collector = ctx->stats_collector(); 105 } 106 opts->runner = ctx->runner(); 107 } 108 109 } // end namespace 110 111 class FunctionalIf : public 
AsyncOpKernel { 112 public: 113 explicit FunctionalIf(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { 114 auto lib = ctx->function_library(); 115 OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library")); 116 const NameAttrList* func; 117 OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &func)); 118 OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &then_handle_)); 119 OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &func)); 120 OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &else_handle_)); 121 } 122 123 ~FunctionalIf() override {} 124 125 void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { 126 bool cond; 127 OP_REQUIRES_OK(ctx, ToBool({ctx->input(0)}, &cond)); 128 (new State(this, ctx, cond, done))->Start(); 129 } 130 131 private: 132 FHandle then_handle_; 133 FHandle else_handle_; 134 135 class State { 136 public: 137 State(FunctionalIf* kernel, OpKernelContext* ctx, bool cond, 138 DoneCallback done) 139 : kernel_(kernel), 140 ctx_(ctx), 141 cond_(cond), 142 done_(done), 143 lib_(CHECK_NOTNULL(ctx_->function_library())) { 144 SetRunOptions(ctx_, &opts_, true /* always_collect_stats */); 145 for (int i = 1; i < ctx_->num_inputs(); ++i) { 146 args_.push_back(ctx_->input(i)); 147 } 148 } 149 150 ~State() {} 151 152 void Start() { 153 FHandle handle = cond_ ? kernel_->then_handle_ : kernel_->else_handle_; 154 rets_.clear(); 155 lib_->Run( 156 // Evaluate one of the branch. 
157 opts_, handle, args_, &rets_, 158 // Done callback 159 [this](Status s) { 160 if (s.ok()) { 161 s = SetOutputs(kernel_, ctx_, rets_); 162 } 163 ctx_->SetStatus(s); 164 auto done = done_; 165 delete this; 166 done(); 167 }); 168 } 169 170 private: 171 FunctionalIf* const kernel_; 172 OpKernelContext* const ctx_; 173 const bool cond_; 174 const DoneCallback done_; 175 FunctionLibraryRuntime* const lib_; 176 FunctionLibraryRuntime::Options opts_; 177 TensorVec args_; 178 TensorVec rets_; 179 }; 180 }; 181 182 REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_CPU), FunctionalIf); 183 REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_GPU).HostMemory("cond"), 184 FunctionalIf); 185 186 class FunctionalWhile : public AsyncOpKernel { 187 public: 188 explicit FunctionalWhile(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { 189 OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &cond_func_)); 190 OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &body_func_)); 191 } 192 193 ~FunctionalWhile() override {} 194 195 void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { 196 auto lib = ctx->function_library(); 197 OP_REQUIRES_ASYNC(ctx, lib != nullptr, 198 errors::Internal("No function library"), done); 199 200 // TODO(b/37549631): Because this op has `SetIsStateful()` in its 201 // op registration, this kernel may be shared by multiple 202 // subgraphs, which have different associated 203 // `FunctionLibraryRuntime` objects and hence different `FHandle` 204 // namespaces. We currently work around this by caching the map 205 // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two 206 // functions this op uses. 
207 FHandle cond_handle; 208 FHandle body_handle; 209 { 210 mutex_lock l(mu_); 211 const auto iter = handles_.find(lib); 212 if (iter == handles_.end()) { 213 OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle), 214 done); 215 OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle), 216 done); 217 handles_[lib] = {cond_handle, body_handle}; 218 } else { 219 cond_handle = iter->second.first; 220 body_handle = iter->second.second; 221 } 222 } 223 224 (new State(this, ctx, cond_handle, body_handle, done))->Start(); 225 } 226 227 private: 228 NameAttrList cond_func_; 229 NameAttrList body_func_; 230 231 mutex mu_; 232 std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>> 233 handles_ GUARDED_BY(mu_); 234 235 class State { 236 public: 237 State(FunctionalWhile* kernel, OpKernelContext* ctx, FHandle cond_handle, 238 FHandle body_handle, DoneCallback done) 239 : kernel_(kernel), 240 ctx_(ctx), 241 cond_handle_(cond_handle), 242 body_handle_(body_handle), 243 done_(done), 244 lib_(CHECK_NOTNULL(ctx_->function_library())) { 245 SetRunOptions(ctx_, &opts_, false /* always_collect_stats */); 246 for (int i = 0; i < ctx_->num_inputs(); ++i) { 247 args_.push_back(ctx_->input(i)); 248 } 249 } 250 251 ~State() {} 252 253 void Start() { EvalCond(); } 254 255 private: 256 FunctionalWhile* const kernel_; 257 OpKernelContext* const ctx_; 258 const FHandle cond_handle_; 259 const FHandle body_handle_; 260 const DoneCallback done_; 261 FunctionLibraryRuntime* const lib_; 262 FunctionLibraryRuntime::Options opts_; 263 TensorVec args_; 264 TensorVec rets_; 265 266 void EvalCond() { 267 lib_->Run( 268 // Evaluate the condition. 269 opts_, cond_handle_, args_, &rets_, 270 // Done cb. 
271 [this](const Status& s) { 272 if (!s.ok()) { 273 return Finish(s); 274 } 275 StartBody(); 276 }); 277 } 278 279 void StartBody() { 280 bool cond; 281 Status s = ToBool(rets_, &cond); 282 if (!s.ok()) { 283 return Finish(s); 284 } 285 if (!cond) { 286 return Finish(Status::OK()); 287 } 288 rets_.clear(); 289 lib_->Run( 290 // Evaluate the body. 291 opts_, body_handle_, args_, &rets_, 292 // Done callback 293 [this](const Status& s) { 294 if (!s.ok()) { 295 return Finish(s); 296 } 297 if (args_.size() != rets_.size()) { 298 return Finish(errors::InvalidArgument( 299 "While loop body returned ", rets_.size(), 300 " arguments. Expected: ", args_.size())); 301 } 302 args_.clear(); 303 using std::swap; 304 swap(args_, rets_); 305 EvalCond(); 306 }); 307 } 308 309 void Finish(Status s) { 310 if (s.ok()) { 311 s = SetOutputs(kernel_, ctx_, args_); 312 } 313 ctx_->SetStatus(s); 314 done_(); 315 delete this; 316 } 317 }; 318 }; 319 REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), FunctionalWhile); 320 REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_GPU), FunctionalWhile); 321 322 } // namespace tensorflow 323