1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/common_runtime/function.h" 17 18 #include <deque> 19 #include <vector> 20 21 #include "absl/algorithm/container.h" 22 #include "absl/strings/str_cat.h" 23 #include "tensorflow/core/common_runtime/device.h" 24 #include "tensorflow/core/common_runtime/executor.h" 25 #include "tensorflow/core/common_runtime/executor_factory.h" 26 #include "tensorflow/core/common_runtime/graph_optimizer.h" 27 #include "tensorflow/core/common_runtime/memory_types.h" 28 #include "tensorflow/core/common_runtime/rendezvous_mgr.h" 29 #include "tensorflow/core/framework/collective.h" 30 #include "tensorflow/core/framework/function.h" 31 #include "tensorflow/core/framework/node_def.pb.h" 32 #include "tensorflow/core/framework/node_def_util.h" 33 #include "tensorflow/core/framework/op.h" 34 #include "tensorflow/core/framework/op_kernel.h" 35 #include "tensorflow/core/framework/versions.pb.h" 36 #include "tensorflow/core/graph/algorithm.h" 37 #include "tensorflow/core/graph/control_flow.h" 38 #include "tensorflow/core/graph/gradients.h" 39 #include "tensorflow/core/graph/graph_constructor.h" 40 #include "tensorflow/core/graph/optimizer_cse.h" 41 #include "tensorflow/core/lib/core/threadpool.h" 42 #include "tensorflow/core/lib/gtl/map_util.h" 43 #include 
"tensorflow/core/platform/macros.h"

// See core/kernels/function_ops.cc for related kernels.

namespace tensorflow {

// A few string constants used throughout this module.
static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
static constexpr const char* const kDeviceArgOp =
    FunctionLibraryDefinition::kDeviceArgOp;
static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
static constexpr const char* const kDeviceRetOp =
    FunctionLibraryDefinition::kDeviceRetOp;
static constexpr const char* const kGradientOp =
    FunctionLibraryDefinition::kGradientOp;
// Prefix used for the names of nodes added by the transformations below.
static constexpr const char* const kNodeLabel = "Func";
static constexpr const char* const kFuncAttr =
    FunctionLibraryDefinition::kFuncAttr;

// Represents the index-th output of a node.
struct Endpoint {
  Node* node;
  int index;

  // Returns the name of this endpoint in NodeDef-input syntax: the bare node
  // name for output 0, otherwise "<node_name>:<index>".
  string name() const {
    if (index == 0) {
      return node->name();
    } else {
      return strings::StrCat(node->name(), ":", index);
    }
  }

  // Data type produced at this output slot of the node.
  DataType dtype() const { return node->output_type(index); }
};

// Hash functor so Endpoint can be used as a key in hash containers; mixes the
// raw node pointer bytes with the output index.
struct EndpointHash {
  uint64 operator()(const Endpoint& x) const {
    return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
                  x.index);
  }
};

// Equality functor companion to EndpointHash: same node AND same output index.
struct EndpointEq {
  bool operator()(const Endpoint& x, const Endpoint& y) const {
    return (x.node == y.node) && (x.index == y.index);
  }
};

// The following Add* routines are used to add a few graph nodes while
// functions are transformed.
94 static Node* AddNoOp(StringPiece name, Graph* g) { 95 NodeDef ndef; 96 ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name))); 97 ndef.set_op("NoOp"); 98 Status s; 99 Node* ret = g->AddNode(ndef, &s); 100 TF_CHECK_OK(s); 101 return ret; 102 } 103 104 static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) { 105 DCHECK_LT(0, input.dtype()); 106 NodeDef ndef; 107 ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name))); 108 ndef.set_op("Identity"); 109 // NOTE(skyewm): we explicitly set the device here to address a multi-GPU 110 // performance issue where this Identity would be placed alone on a GPU, 111 // causing unnecessary device traffic. See b/122483225 for details. 112 ndef.set_device(input.node->def().device()); 113 ndef.add_input(input.name()); 114 AddNodeAttr("T", BaseType(input.dtype()), &ndef); 115 Status s; 116 Node* ret = g->AddNode(ndef, &s); 117 TF_CHECK_OK(s); 118 g->AddEdge(input.node, input.index, ret, 0); 119 return ret; 120 } 121 122 static Node* AddArg(Graph* g, DataType dtype, int index) { 123 DCHECK_LT(0, dtype); 124 DCHECK_LT(dtype, DT_FLOAT_REF); 125 NodeDef ndef; 126 ndef.set_name(g->NewName(kNodeLabel)); 127 ndef.set_op(kArgOp); 128 AddNodeAttr("T", dtype, &ndef); 129 AddNodeAttr("index", index, &ndef); 130 Status s; 131 Node* ret = g->AddNode(ndef, &s); 132 TF_CHECK_OK(s); 133 return ret; 134 } 135 136 static Node* AddRet(Graph* g, Endpoint input, int index) { 137 DCHECK_LT(0, input.dtype()); 138 DCHECK_LT(input.dtype(), DT_FLOAT_REF); 139 NodeDef ndef; 140 ndef.set_name(g->NewName(kNodeLabel)); 141 ndef.set_op(kRetOp); 142 ndef.add_input(input.name()); 143 AddNodeAttr("T", input.dtype(), &ndef); 144 AddNodeAttr("index", index, &ndef); 145 Status s; 146 Node* ret = g->AddNode(ndef, &s); 147 TF_CHECK_OK(s); 148 g->AddEdge(input.node, input.index, ret, 0); 149 return ret; 150 } 151 152 // FunctionLibraryRuntime implementation that forwards all the function calls to 153 // the base runtime implementation, and 
// only overrides overlay lib in calls to
// Instantiate (if caller doesn't provide its own overlay lib).
//
// When function library runtime (FunctionLibraryRuntimeImpl specifically)
// instantiates function into a Graph object, it also creates an Executor for
// it. That executor has a pointer to the function library runtime instance,
// that is used to instantiate all nested function calls.
//
// If the original function was instantiated using overlay lib, we must
// preserve that overlay lib in the executor's function library runtime.
//
// IMPORTANT: This runtime is intended for use only in executors created for
// functions instantiated into a graph in FunctionLibraryRuntimeImpl.
class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
 public:
  FunctionLibraryRuntimeOverlay(
      FunctionLibraryRuntime* base_flr,
      const FunctionLibraryDefinition* overlay_lib_def)
      : base_flr_(base_flr), overlay_lib_def_(overlay_lib_def) {}
  ~FunctionLibraryRuntimeOverlay() override;

  Status Instantiate(const string& function_name, AttrSlice attrs,
                     const InstantiateOptions& options,
                     Handle* handle) override;

  Status ReleaseHandle(Handle handle) override;

  const FunctionBody* GetFunctionBody(Handle h) override;

  void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
           std::vector<Tensor>* rets, DoneCallback done) override;

  void Run(const Options& opts, Handle handle, CallFrameInterface* call_frame,
           DoneCallback done) override;

  Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override;

  bool IsStateful(const string& function_name) override;

  const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const override;

  Env* env() override;
  Device* device() override;
  std::function<void(std::function<void()>)>* runner() override;
  const DeviceMgr* device_mgr() const override;

  string DebugString(Handle handle) override;
  int graph_def_version() override;

  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
               FunctionLibraryRuntime** out_flr) override;

 private:
  FunctionLibraryRuntime* base_flr_;                  // not owned
  const FunctionLibraryDefinition* overlay_lib_def_;  // not owned
};

FunctionLibraryRuntimeOverlay::~FunctionLibraryRuntimeOverlay() = default;

Status FunctionLibraryRuntimeOverlay::Instantiate(
    const string& function_name, AttrSlice attrs,
    const InstantiateOptions& options, Handle* handle) {
  // We automatically add overlay lib to all instantiations, if the caller
  // doesn't provide its own override.
  if (!options.overlay_lib && overlay_lib_def_) {
    InstantiateOptions options_copy = options;
    options_copy.overlay_lib = overlay_lib_def_;
    return base_flr_->Instantiate(function_name, attrs, options_copy, handle);
  } else {
    return base_flr_->Instantiate(function_name, attrs, options, handle);
  }
}

// Handle bookkeeping is delegated entirely to the base runtime.
Status FunctionLibraryRuntimeOverlay::ReleaseHandle(Handle handle) {
  return base_flr_->ReleaseHandle(handle);
}

const FunctionBody* FunctionLibraryRuntimeOverlay::GetFunctionBody(Handle h) {
  return base_flr_->GetFunctionBody(h);
}

// Both Run overloads simply forward execution to the base runtime.
void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
                                        gtl::ArraySlice<Tensor> args,
                                        std::vector<Tensor>* rets,
                                        DoneCallback done) {
  base_flr_->Run(opts, handle, args, rets, std::move(done));
}

void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
                                        CallFrameInterface* call_frame,
                                        DoneCallback done) {
  base_flr_->Run(opts, handle, call_frame, std::move(done));
}

Status FunctionLibraryRuntimeOverlay::CreateKernel(const NodeDef&, OpKernel**) {
  // We don't have access base_lib_def_ in base function library runtime (aka
  // FunctionLibraryRuntimeImpl), so to make sure we do not create kernel with
  // wrong lib_def we just disable creation of new kernels through overlays.
  //
  // When we call Instantiate from the base runtime with overlay lib override,
  // the base runtime implementation is responsible for correctly passing
  // custom overlay lib to all kernel constructions.
  return errors::Internal(
      "Overlay function library runtime doesn't support kernel creation.");
}

bool FunctionLibraryRuntimeOverlay::IsStateful(const string& function_name) {
  // Important: we do not forward lookup to the base FLR.
  const OpDef* op_def;
  const Status s = overlay_lib_def_->LookUpOpDef(function_name, &op_def);
  return s.ok() && op_def->is_stateful();
}

Env* FunctionLibraryRuntimeOverlay::env() { return base_flr_->env(); }

Device* FunctionLibraryRuntimeOverlay::device() { return base_flr_->device(); }

std::function<void(std::function<void()>)>*
FunctionLibraryRuntimeOverlay::runner() {
  return base_flr_->runner();
}

const DeviceMgr* FunctionLibraryRuntimeOverlay::device_mgr() const {
  return base_flr_->device_mgr();
}

// Prefers the overlay library when one was supplied; falls back to the base
// runtime's library otherwise.
const FunctionLibraryDefinition*
FunctionLibraryRuntimeOverlay::GetFunctionLibraryDefinition() const {
  return overlay_lib_def_ ? overlay_lib_def_
                          : base_flr_->GetFunctionLibraryDefinition();
}

string FunctionLibraryRuntimeOverlay::DebugString(Handle handle) {
  return base_flr_->DebugString(handle);
}

int FunctionLibraryRuntimeOverlay::graph_def_version() {
  return base_flr_->graph_def_version();
}

Status FunctionLibraryRuntimeOverlay::Clone(
    std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
    FunctionLibraryRuntime** out_flr) {
  // NOTE(ezhulenev): Cloned FunctionLibraryRuntime will be missing overlay
  // lib, but that's ok because we anyway do not copy/clone instantiated items
  // from the base FLR.
  return base_flr_->Clone(out_lib_def, out_pflr, out_flr);
}

// The "real" function library runtime: owns instantiated function bodies and
// per-function executors for a single device.
class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
 public:
  FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env, Device* device,
                             int graph_def_version,
                             const FunctionLibraryDefinition* lib_def,
                             thread::ThreadPool* default_thread_pool,
                             const OptimizerOptions& optimizer_options,
                             CustomKernelCreator custom_kernel_creator,
                             ProcessFunctionLibraryRuntime* parent);

  ~FunctionLibraryRuntimeImpl() override;

  Status Instantiate(const string& function_name, AttrSlice attrs,
                     const InstantiateOptions& options,
                     Handle* handle) override;

  Status ReleaseHandle(Handle handle) override;

  const FunctionBody* GetFunctionBody(Handle handle) override;

  Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override;

  void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
           std::vector<Tensor>* rets, DoneCallback done) override;
  // NOTE(mrry): This overload is currently only implemented for local function
  // execution.
  // TODO(b/70346412): Implement support for remote function execution when
  // passing a call frame.
  void Run(const Options& opts, Handle handle, CallFrameInterface* frame,
           DoneCallback done) override;

  bool IsStateful(const string& function) override;

  const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const override {
    return base_lib_def_;
  }

  Device* device() override { return device_; }

  std::function<void(std::function<void()>)>* runner() override {
    return &default_runner_;
  }

  const DeviceMgr* device_mgr() const override { return device_mgr_; }
  Env* env() override { return env_; }
  int graph_def_version() override { return graph_def_version_; }

  string DebugString(Handle h) override;

  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
               FunctionLibraryRuntime** out_flr) override;

 private:
  typedef FunctionLibraryRuntimeImpl ME;

  const DeviceMgr* const device_mgr_;
  Device* const device_;
  Env* const env_;
  const int graph_def_version_;
  const FunctionLibraryDefinition* const base_lib_def_;
  GraphOptimizer optimizer_;
  const CustomKernelCreator custom_kernel_creator_;
  Executor::Args::Runner default_runner_;
  const string device_name_;

  // Signature lookup and kernel creation callbacks bound to this runtime;
  // initialized in the constructor.
  std::function<Status(const string&, const OpDef**)> get_func_sig_;
  std::function<Status(const NodeDef&, OpKernel**)> create_kernel_;

  mutable mutex mu_;

  // Next local handle to assign; monotonically increasing under mu_.
  int next_handle_ GUARDED_BY(mu_);

  // The instantiated and transformed function is encoded as a Graph
  // object, and an executor is created for the graph.
  struct Item {
    uint64 instantiation_counter = 0;
    const Graph* graph = nullptr;                            // Owned by exec.
    const FunctionLibraryDefinition* overlay_lib = nullptr;  // Not owned.
    FunctionBody* func_graph = nullptr;
    Executor* exec = nullptr;
    FunctionLibraryRuntimeOverlay* overlay_flr = nullptr;
    string executor_type;

    ~Item() {
      delete this->func_graph;
      delete this->exec;
      delete this->overlay_flr;
    }
  };
  std::unordered_map<Handle, std::unique_ptr<Item>> items_ GUARDED_BY(mu_);

  ProcessFunctionLibraryRuntime* parent_ = nullptr;  // not owned.

  // Overload that creates the kernel against an explicit (possibly overlay)
  // function library definition.
  Status CreateKernel(const NodeDef& ndef,
                      const FunctionLibraryDefinition* lib_def,
                      OpKernel** kernel);
  Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
                           const FunctionLibraryDefinition* lib_def,
                           FunctionBody** fbody);
  Status CreateItem(Item** item);
  Status GetOrCreateItem(LocalHandle local_handle, Item** item);
  Status InstantiateSymbolicGradient(const NameAttrList& func,
                                     const FunctionLibraryDefinition* lib_def,
                                     FunctionBody** g_body);
  bool IsLocalTarget(const InstantiateOptions& options);
  AttrValueMap FixAttrs(const AttrSlice& attrs);
  void RunRemote(const Options& opts, Handle handle,
                 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                 Item* item, DoneCallback done);

  // Copies the relevant fields of `run_opts` into `exec_args`.
  void ExecutorArgsFromOptions(const FunctionLibraryRuntime::Options& run_opts,
                               CallFrameInterface* frame,
                               Executor::Args* exec_args);

  TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
};

FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
    const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version,
    const FunctionLibraryDefinition* lib_def,
    thread::ThreadPool* default_thread_pool,
    const OptimizerOptions& optimizer_options,
    CustomKernelCreator custom_kernel_creator,
    ProcessFunctionLibraryRuntime* parent)
    : device_mgr_(dmgr),
      device_(device),
      env_(env),
      graph_def_version_(graph_def_version),
      base_lib_def_(lib_def),
      optimizer_(optimizer_options),
      custom_kernel_creator_(std::move(custom_kernel_creator)),
      default_runner_(nullptr),
      device_name_(device_ == nullptr
                       ? ProcessFunctionLibraryRuntime::kDefaultFLRDevice
                       : device_->name()),
      next_handle_(0),
      parent_(parent) {
  get_func_sig_ = [this](const string& op, const OpDef** sig) {
    return base_lib_def_->LookUpOpDef(op, sig);
  };
  create_kernel_ = [this](const NodeDef& ndef, OpKernel** kernel) {
    return CreateKernel(ndef, kernel);
  };
  // Pick a runner: the device's own thread pool if it has one, otherwise the
  // provided default pool; if neither exists, default_runner_ stays null.
  thread::ThreadPool* pool = nullptr;
  if (device_ != nullptr) {
    pool = device_->tensorflow_device_thread_pool();
  }
  if (pool == nullptr) {
    pool = default_thread_pool;
  }
  if (pool != nullptr) {
    default_runner_ = [pool](Executor::Args::Closure c) {
      pool->Schedule(std::move(c));
    };
  }
}

FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {}

// An asynchronous op kernel which executes an instantiated function
// defined in a library.
class CallOp : public AsyncOpKernel {
 public:
  CallOp(FunctionLibraryRuntime::Handle handle, OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx), handle_(handle) {}

  ~CallOp() override {
    // TODO(iga): Release the cached handle_
  }

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    FunctionLibraryRuntime* lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library is provided."),
                      done);
    // Propagate the caller's execution context into the function call.
    FunctionLibraryRuntime::Options opts;
    opts.step_id = ctx->step_id();
    opts.rendezvous = ctx->rendezvous();
    opts.cancellation_manager = ctx->cancellation_manager();
    opts.step_container = ctx->step_container();
    opts.stats_collector = ctx->stats_collector();
    opts.runner = ctx->runner();
    opts.collective_executor = ctx->collective_executor();
    std::vector<Tensor> args;
    args.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      args.push_back(ctx->input(i));
    }
    // `rets` is heap-allocated because it must outlive this stack frame; it
    // is deleted inside the completion callback below.
    std::vector<Tensor>* rets = new std::vector<Tensor>;
    lib->Run(opts, handle_, args, rets,
             [ctx, done, rets](const Status& status) {
               if (!status.ok()) {
                 ctx->SetStatus(status);
               } else {
                 const int ret_size = static_cast<int>(rets->size());
                 CHECK_EQ(ret_size, ctx->num_outputs());
                 for (int i = 0; i < ret_size; ++i) {
                   ctx->set_output(i, (*rets)[i]);
                 }
               }
               delete rets;
               done();
             });
  }

 private:
  FunctionLibraryRuntime::Handle handle_;

  TF_DISALLOW_COPY_AND_ASSIGN(CallOp);
};

const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h);
  if (local_handle == kInvalidLocalHandle) {
    LOG(ERROR) << "Could not find Handle: " << h
               << " on device: " << device_name_;
    return nullptr;
  }

  tf_shared_lock l(mu_);
  auto iter = items_.find(local_handle);
  CHECK(iter != items_.end());
  return iter->second->func_graph;
}

Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
                                                OpKernel** kernel) {
  return CreateKernel(ndef, base_lib_def_, kernel);
}

Status FunctionLibraryRuntimeImpl::CreateKernel(
    const NodeDef& ndef, const FunctionLibraryDefinition* lib_def,
    OpKernel** kernel) {
  // If a custom kernel creator is given, try that.
  Status s;
  if (custom_kernel_creator_) {
    std::unique_ptr<OpKernel> ret;
    s = custom_kernel_creator_(this, ndef, &ret);
    if (s.ok()) {
      *kernel = ret.release();
      return s;
    } else {
      VLOG(2) << "Custom creator error: " << s;
      // Falls through.
      s = Status::OK();
    }
  }

  if (lib_def->Find(ndef.op()) == nullptr) {
    // A primitive operation. Creates the registered kernel.
    return CreateNonCachedKernel(device_, this, ndef, graph_def_version_,
                                 kernel);
  }

  // Try to instantiate this function for the func/attr. Maybe it's
  // cached already.
  InstantiateOptions options;
  if (lib_def != base_lib_def_) {
    options.overlay_lib = lib_def;
  }
  Handle handle;
  TF_RETURN_IF_ERROR(
      Instantiate(ndef.op(), AttrSlice(&ndef.attr()), options, &handle));

  const FunctionBody* fbody = GetFunctionBody(handle);
  CHECK_NOTNULL(fbody);

  // TODO(zhifengc): For now, we assume int32 and resources are always on host
  // memory and other types are always on device memory. We should do type
  // inference over function body to derive the correct input/output memory
  // types.
  MemoryTypeVector input_memory_types;
  for (const auto& t : fbody->arg_types) {
    input_memory_types.push_back(MTypeFromDType(t));
  }
  MemoryTypeVector output_memory_types;
  for (const auto& t : fbody->ret_types) {
    output_memory_types.push_back(MTypeFromDType(t));
  }

  // Constructs a CallOp kernel for running the instantiated function.
  auto device_type = DeviceType(device_->attributes().device_type());
  OpKernelConstruction construction(
      device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef,
      &fbody->fdef.signature(), this, fbody->arg_types, input_memory_types,
      fbody->ret_types, output_memory_types, graph_def_version_, &s);
  if (s.ok()) {
    *kernel = new CallOp(handle, &construction);
  }
  return s;
}

// Converts `fdef` into a FunctionBody, resolving op signatures against
// `lib_def` (using the cached lookup callback when it is the base library).
Status FunctionLibraryRuntimeImpl::FunctionDefToBody(
    const FunctionDef& fdef, AttrSlice attrs,
    const FunctionLibraryDefinition* lib_def, FunctionBody** fbody) {
  if (lib_def == base_lib_def_) {
    return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig_, fbody);
  } else {
    auto get_func_sig = [lib_def](const string& op, const OpDef** sig) {
      return lib_def->LookUpOpDef(op, sig);
    };
    return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody);
  }
}

// Builds the body of the symbolic gradient of `func`: for primitive ops it
// uses the registered gradient creator; for user-defined functions it
// instantiates the function and differentiates its body.
Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
    const NameAttrList& func, const FunctionLibraryDefinition* lib_def,
    FunctionBody** g_body) {
  const FunctionDef* fdef = lib_def->Find(func.name());
  if (fdef == nullptr) {
    // f is a primitive op.
    gradient::Creator creator;
    TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
    if (creator == nullptr) {
      return errors::InvalidArgument("No gradient is defined for ",
                                     func.name());
    }
    FunctionDef grad_fdef;
    // TODO(josh11b): Should filter out the attrs from func that aren't used
    // by the gradient function.
    TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
    TF_RETURN_IF_ERROR(
        FunctionDefToBody(grad_fdef, AttrSlice(&func.attr()), lib_def, g_body));
  } else {
    // f is a user-defined function.
    InstantiateOptions options;
    if (lib_def != base_lib_def_) {
      options.overlay_lib = lib_def;
    }
    Handle f_handle;
    TF_RETURN_IF_ERROR(
        Instantiate(func.name(), AttrSlice(&func.attr()), options, &f_handle));
    const FunctionBody* f_body = GetFunctionBody(f_handle);
    CHECK_NOTNULL(f_body);
    *g_body = SymbolicGradient(*f_body);
  }
  return Status::OK();
}

// Returns true iff the instantiation should happen in this (device-local)
// runtime rather than be forwarded to the process-level runtime.
bool FunctionLibraryRuntimeImpl::IsLocalTarget(
    const InstantiateOptions& options) {
  if (device_ == nullptr) return true;
  if (options.target.empty()) return true;
  if (options.is_multi_device_function) return false;
  Device* target_device;
  if (!device_mgr_->LookupDevice(options.target, &target_device).ok()) {
    VLOG(1) << "Not instantiating function in FLR because failed to "
            << "find device " << options.target << " in device manager";
    return false;
  }
  if (target_device != device_) {
    VLOG(1) << "Not instantiating function in FLR because target device "
            << options.target
            << " is different from FLR's device: " << device_->DebugString();
    return false;
  }
  return true;
}

Status FunctionLibraryRuntimeImpl::Instantiate(
    const string& function_name, AttrSlice attrs,
    const InstantiateOptions& options, Handle* handle) {
  if (!IsLocalTarget(options)) {
    return parent_->Instantiate(function_name, attrs, options, handle);
  }

  // Since this is a local target, ensure that the local `device_name_` appears
  // in the canonical key.
  InstantiateOptions options_copy(options);
  options_copy.target = device_name_;
  const string key = Canonicalize(function_name, attrs, options_copy);

  // Fast path: the function is already instantiated; just bump its refcount.
  {
    mutex_lock l(mu_);
    *handle = parent_->GetHandle(key);
    if (*handle != kInvalidHandle) {
      FunctionLibraryRuntime::LocalHandle handle_on_device =
          parent_->GetHandleOnDevice(device_name_, *handle);
      if (handle_on_device == kInvalidLocalHandle) {
        return errors::Internal("LocalHandle not found for handle ", *handle,
                                ".");
      }
      auto item_handle = items_.find(handle_on_device);
      if (item_handle == items_.end()) {
        return errors::Internal("LocalHandle ", handle_on_device,
                                " for handle ", *handle,
                                " not found in items.");
      }
      ++item_handle->second->instantiation_counter;
      return Status::OK();
    }
  }

  // Build the function body outside the lock (this can be expensive).
  const FunctionLibraryDefinition* lib_def =
      options.overlay_lib ? options.overlay_lib : base_lib_def_;
  FunctionBody* fbody = nullptr;
  if (function_name == kGradientOp) {
    const AttrValue* f = attrs.Find(kFuncAttr);
    if (f == nullptr) {
      return errors::InvalidArgument("SymbolicGradient is missing attr: f");
    }
    const auto& func = f->func();
    if (func.name() == kGradientOp) {
      return errors::InvalidArgument("Can't take gradient of SymbolicGradient");
    }
    const string grad = lib_def->FindGradient(func.name());
    if (!grad.empty()) {
      return Instantiate(grad, AttrSlice(&func.attr()), options, handle);
    }
    TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, lib_def, &fbody));
  } else {
    const FunctionDef* fdef = lib_def->Find(function_name);
    if (fdef == nullptr) {
      return errors::NotFound("Function ", function_name, " is not defined.");
    }
    TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, lib_def, &fbody));
  }

  // Re-check under the lock: another thread may have instantiated the same
  // key while we were building the body (double-checked registration).
  LocalHandle local_handle;
  {
    mutex_lock l(mu_);
    *handle = parent_->GetHandle(key);
    if (*handle != kInvalidHandle) {
      delete fbody;
      local_handle = parent_->GetHandleOnDevice(device_name_, *handle);
      ++items_[local_handle]->instantiation_counter;
    } else {
      *handle = parent_->AddHandle(key, device_name_, next_handle_);
      Item* item = new Item;
      item->func_graph = fbody;
      item->overlay_lib = options.overlay_lib;
      item->instantiation_counter = 1;
      item->executor_type = ExecutorType(options, attrs);
      if (options.overlay_lib) {
        item->overlay_flr =
            new FunctionLibraryRuntimeOverlay(this, options.overlay_lib);
      }
      local_handle = next_handle_++;
      items_.emplace(local_handle, std::unique_ptr<Item>(item));
    }
  }

  if (options.create_kernels_eagerly) {
    Item* item;
    TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, &item));
  }

  return Status::OK();
}

Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
  LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
  if (h == kInvalidLocalHandle) {
    return parent_->ReleaseHandle(handle);
  }

  std::unique_ptr<Item> item_to_delete;
  Status parent_status;
  {
    mutex_lock l(mu_);
    auto it = items_.find(h);
    if (it == items_.end()) {
      return errors::Internal(
          "Inconsistent FunctionLibraryRuntime. Expected to find an item for "
          "handle ",
          h, " but found none");
    }
    std::unique_ptr<Item>& item = it->second;
    --item->instantiation_counter;
    if (item->instantiation_counter == 0) {
      // We don't simply erase h's item because that would trigger
      // item destruction while holding mu_. Item destruction can
      // trigger graph destruction. If the graph contains kernels like
      // CallOp or PartitionCallOp, their destructors will release cached
      // function handles, resulting in deadlock here.
782 item_to_delete = std::move(item); 783 items_.erase(h); 784 parent_status = parent_->RemoveHandle(handle); 785 } 786 } 787 return parent_status; 788 } 789 790 void DumpGraph(StringPiece label, const Graph* g) { 791 // TODO(zhifengc): Change Graph to record #nodes. 792 VLOG(1) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges " 793 << g->num_edges(); 794 if (VLOG_IS_ON(2)) { 795 for (const auto& line : str_util::Split(DebugString(g), '\n')) { 796 VLOG(2) << "|| " << line; 797 } 798 } 799 } 800 801 void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g, 802 const GraphOptimizer::Options& graph_optimizer_options) { 803 OptimizerOptions opts; 804 opts.set_do_common_subexpression_elimination(true); 805 opts.set_do_function_inlining(true); 806 opts.set_do_constant_folding(true); 807 GraphOptimizer optimizer(opts); 808 optimizer.Optimize(lib, lib->env(), lib->device(), g, 809 graph_optimizer_options); 810 } 811 812 void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g) { 813 OptimizeGraph(lib, g, GraphOptimizer::Options()); 814 } 815 816 namespace { 817 // Removes all stateless nodes that do not contribute to a return 818 // value from the function body. Unlike `RemoveDeadNodes()`, which is 819 // triggered by `OptimizerOptions.do_function_inlining`, this pass 820 // ignores the SINK node, from which (by definition) all nodes are 821 // reverse reachable, and preserves all nodes that are reachable from 822 // control output nodes. 823 // 824 // TODO(ezhulenev, skyewm): Function body should not have special treatment of 825 // stateful ops, graph should encode nodes that must execute with `control_ret` 826 // and `control_output`. 827 void PruneFunctionBody(const FunctionDef& fdef, Graph* g) { 828 VLOG(2) << "Pruning function body: function_name=" << fdef.signature().name(); 829 830 // `control_ret` nodes must be always executed. 
831 std::unordered_set<StringPiece, StringPieceHasher> control_ret_nodes; 832 for (const auto& control_ret : fdef.control_ret()) { 833 control_ret_nodes.insert(control_ret.second); 834 } 835 836 std::unordered_set<const Node*> nodes; 837 for (auto n : g->nodes()) { 838 // NOTE(mrry): "_Retval" nodes are stateful, and so will be added 839 // to the seed set of `nodes`. "_Arg" nodes are also stateful, but we 840 // specifically exclude them as seeds, to avoid unconditionally executing 841 // unused argument nodes (e.g. in a function like `lambda x, y: y`). 842 // TODO(mrry): Investigate whether the `n->IsControlFlow()` test is 843 // still needed. It would be preferable to prune entire loops and/or 844 // conditionals if they are not used in the graph. 845 if (n->IsControlFlow() || 846 (n->op_def().is_stateful() && n->type_string() != kArgOp) || 847 (control_ret_nodes.find(n->name()) != control_ret_nodes.end())) { 848 nodes.insert(n); 849 } 850 } 851 bool changed = PruneForReverseReachability(g, std::move(nodes)); 852 if (changed) { 853 FixupSourceAndSinkEdges(g); 854 } 855 } 856 } // namespace 857 858 Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) { 859 const FunctionBody* fbody; 860 const FunctionLibraryDefinition* lib_def; 861 string executor_type; 862 { 863 tf_shared_lock l(mu_); 864 fbody = (*item)->func_graph; 865 lib_def = (*item)->overlay_lib; 866 executor_type = (*item)->executor_type; 867 } 868 if (!lib_def) { 869 lib_def = base_lib_def_; 870 } 871 std::unique_ptr<Graph> g(new Graph(lib_def)); 872 CopyGraph(*fbody->graph, g.get()); 873 874 PruneFunctionBody(fbody->fdef, g.get()); 875 optimizer_.Optimize(this, env(), device(), &g, /*shape_map=*/nullptr); 876 TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device()->device_type()), 877 device()->name(), g.get())); 878 879 // Creates an executor based on the g. This must be done without 880 // holding mu_ because create_kernel_ calls back into the library. 
  LocalExecutorParams params;
  params.device = device_;
  // Prefer the overlay FLR (if this instantiation carries one) so kernels
  // resolve functions against the overlay library.
  params.function_library =
      (*item)->overlay_flr
          ? static_cast<FunctionLibraryRuntime*>((*item)->overlay_flr)
          : static_cast<FunctionLibraryRuntime*>(this);
  if (lib_def == base_lib_def_) {
    params.create_kernel = create_kernel_;
  } else {
    // Overlay library: kernels must be created against `lib_def`.
    params.create_kernel = [this, lib_def](const NodeDef& ndef,
                                           OpKernel** kernel) {
      return CreateKernel(ndef, lib_def, kernel);
    };
  }
  params.delete_kernel = [](OpKernel* kernel) {
    DeleteNonCachedKernel(kernel);
  };
  Graph* graph = g.get();
  std::unique_ptr<Executor> exec;
  TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, std::move(g), &exec));
  {
    // Guard item since it is already inserted in items_.
    mutex_lock l(mu_);
    if ((*item)->exec == nullptr) {
      // First creator wins; a concurrent CreateItem may have raced us here.
      (*item)->graph = graph;
      (*item)->exec = exec.release();
    }
  }
  return Status::OK();
}

// Looks up `local_handle` in items_ and lazily builds its executor via
// CreateItem() on first use.
Status FunctionLibraryRuntimeImpl::GetOrCreateItem(LocalHandle local_handle,
                                                   Item** item) {
  {
    tf_shared_lock l(mu_);
    auto iter = items_.find(local_handle);
    if (iter == items_.end()) {
      return errors::Internal("Local function handle ", local_handle,
                              " is not valid. Likely an internal error.");
    }
    *item = iter->second.get();
    if ((*item)->exec != nullptr) {
      // Executor already built; nothing to do.
      return Status::OK();
    }
  }
  // NOTE: We need to call CreateItem out of mu_ because creating an
  // executor needs to call CreateKernel.
  return CreateItem(item);
}

// Copies the relevant fields of `run_opts` into `exec_args`, falling back to
// the default runner when the caller did not supply one.
void FunctionLibraryRuntimeImpl::ExecutorArgsFromOptions(
    const FunctionLibraryRuntime::Options& run_opts, CallFrameInterface* frame,
    Executor::Args* exec_args) {
  // Inherit the step_id from the caller.
  exec_args->step_id = run_opts.step_id;
  exec_args->rendezvous = run_opts.rendezvous;
  exec_args->stats_collector = run_opts.stats_collector;
  exec_args->cancellation_manager = run_opts.cancellation_manager;
  exec_args->step_container = run_opts.step_container;
  if (run_opts.runner) {
    exec_args->runner = *run_opts.runner;
  } else {
    exec_args->runner = default_runner_;
  }
  exec_args->collective_executor = run_opts.collective_executor;
  exec_args->call_frame = frame;
}

// Runs a function whose handle lives on a remote device: receives the
// arguments through the rendezvous, executes, and (in the continuation) sends
// the return values back. `frame` and `exec_args` are heap-allocated here and
// deleted by the async callbacks below.
void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
                                           gtl::ArraySlice<Tensor> args,
                                           std::vector<Tensor>* rets,
                                           Item* item, DoneCallback done) {
  string target_device = parent_->GetDeviceName(handle);
  string source_device = opts.source_device;
  Rendezvous* rendezvous = opts.rendezvous;
  DeviceContext* device_context;
  Status s = parent_->GetDeviceContext(target_device, &device_context);
  if (!s.ok()) {
    done(s);
    return;
  }
  int64 src_incarnation, target_incarnation;
  s = parent_->GetDeviceIncarnation(source_device, &src_incarnation);
  s.Update(parent_->GetDeviceIncarnation(target_device, &target_incarnation));
  if (!s.ok()) {
    done(s);
    return;
  }

  const FunctionBody* fbody = GetFunctionBody(handle);
  // Manually managed: ownership passes to the callbacks registered below.
  FunctionCallFrame* frame =
      new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
  Executor::Args* exec_args = new Executor::Args;
  ExecutorArgsFromOptions(opts, frame, exec_args);

  std::vector<AllocatorAttributes> args_alloc_attrs, rets_alloc_attrs;
  args_alloc_attrs.reserve(fbody->arg_types.size());
  rets_alloc_attrs.reserve(fbody->ret_types.size());
  // Note: Functions assume that int32's are always on host memory.
  // Mark host-memory arguments/returns so receive/send use host allocators.
  for (const auto& arg_type : fbody->arg_types) {
    AllocatorAttributes arg_alloc_attrs;
    if (MTypeFromDType(arg_type) == HOST_MEMORY) {
      arg_alloc_attrs.set_on_host(true);
    }
    args_alloc_attrs.push_back(arg_alloc_attrs);
  }
  for (const auto& ret_type : fbody->ret_types) {
    AllocatorAttributes ret_alloc_attrs;
    if (MTypeFromDType(ret_type) == HOST_MEMORY) {
      ret_alloc_attrs.set_on_host(true);
    }
    rets_alloc_attrs.push_back(ret_alloc_attrs);
  }

  bool allow_dead_tensors = opts.allow_dead_tensors;

  // The ProcFLR sends the arguments to the function from the source_device to
  // the target_device. So here we receive those arguments. Similarly, when the
  // computation is done and stored in *rets, we send the return values back
  // to the source_device (caller) so that the ProcFLR can receive them later.
  //
  // Lifetime note: `frame`, `exec_args` and `remote_args` are raw heap
  // allocations deleted on every exit path of the nested callbacks below;
  // keep all paths covered when editing.
  std::vector<Tensor>* remote_args = new std::vector<Tensor>;
  ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
      source_device, target_device, "arg_", src_incarnation, args.size(),
      device_context, args_alloc_attrs, rendezvous, remote_args,
      [frame, remote_args, item, source_device, target_device,
       target_incarnation, rendezvous, device_context, rets, done, exec_args,
       rets_alloc_attrs, allow_dead_tensors](const Status& status) {
        Status s = status;
        if (s.ok()) {
          s = frame->SetArgs(*remote_args);
        }
        if (!s.ok()) {
          // Receive or SetArgs failed: release everything and report.
          delete frame;
          delete remote_args;
          delete exec_args;
          done(s);
          return;
        }
        item->exec->RunAsync(
            *exec_args,
            [frame, rets, done, source_device, target_device,
             target_incarnation, rendezvous, device_context, remote_args,
             rets_alloc_attrs, allow_dead_tensors](const Status& status) {
              Status s = status;
              if (s.ok()) {
                s = frame->ConsumeRetvals(rets, allow_dead_tensors);
              }
              delete frame;
              if (!s.ok()) {
                delete remote_args;
                done(s);
                return;
              }
              // Ship the results back to the caller's device.
              s = ProcessFunctionLibraryRuntime::SendTensors(
                  target_device, source_device, "ret_", target_incarnation,
                  *rets, device_context, rets_alloc_attrs, rendezvous);
              delete remote_args;
              done(s);
            });
        // Safe to delete: RunAsync has copied what it needs from *exec_args.
        delete exec_args;
      });
}

// Runs an instantiated function locally (or forwards to the parent ProcFLR
// when the handle is not local to this device).
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
                                     gtl::ArraySlice<Tensor> args,
                                     std::vector<Tensor>* rets,
                                     DoneCallback done) {
  if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
    done(errors::Cancelled(""));
    return;
  }
  Options run_opts = opts;
  if (opts.create_rendezvous) {
    // Create a per-call rendezvous; the wrapped `done` unrefs it when the
    // call completes.
    Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
    run_opts.rendezvous = rendezvous;
    run_opts.create_rendezvous = false;
    done = [done, rendezvous](const Status& status) {
      rendezvous->Unref();
      done(status);
    };
  }

  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
  if (local_handle == kInvalidLocalHandle) {
    // Not instantiated on this device: delegate to the parent runtime.
    parent_->Run(run_opts, handle, args, rets, done);
    return;
  }

  if (run_opts.runner == nullptr) {
    run_opts.runner = &default_runner_;
  }
  DCHECK(run_opts.runner != nullptr);

  Item* item = nullptr;
  Status s = GetOrCreateItem(local_handle, &item);
  if (!s.ok()) {
    done(s);
    return;
  }

  if (run_opts.remote_execution) {
    // NOTE(mrry): `RunRemote()` will set `exec_args->call_frame` for us.
1083 RunRemote(run_opts, handle, args, rets, item, done); 1084 return; 1085 } 1086 1087 const FunctionBody* fbody = GetFunctionBody(handle); 1088 FunctionCallFrame* frame = 1089 new FunctionCallFrame(fbody->arg_types, fbody->ret_types); 1090 s = frame->SetArgs(args); 1091 if (!s.ok()) { 1092 delete frame; 1093 done(s); 1094 return; 1095 } 1096 1097 Executor::Args exec_args; 1098 ExecutorArgsFromOptions(run_opts, frame, &exec_args); 1099 1100 bool allow_dead_tensors = run_opts.allow_dead_tensors; 1101 item->exec->RunAsync( 1102 // Executor args 1103 exec_args, 1104 // Done callback. 1105 [frame, rets, done, allow_dead_tensors](const Status& status) { 1106 Status s = status; 1107 if (s.ok()) { 1108 s = frame->ConsumeRetvals(rets, allow_dead_tensors); 1109 } 1110 delete frame; 1111 done(s); 1112 }); 1113 } 1114 1115 void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, 1116 CallFrameInterface* frame, 1117 DoneCallback done) { 1118 if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) { 1119 done(errors::Cancelled("")); 1120 return; 1121 } 1122 LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle); 1123 if (local_handle == kInvalidLocalHandle || opts.remote_execution) { 1124 done(errors::Unimplemented("Remote calling with CallFrameInterface")); 1125 return; 1126 } 1127 1128 Options run_opts = opts; 1129 if (opts.create_rendezvous) { 1130 Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_); 1131 run_opts.rendezvous = rendezvous; 1132 run_opts.create_rendezvous = false; 1133 done = std::bind( 1134 [rendezvous](DoneCallback done, 1135 // Begin unbound arguments. 
1136 const Status& status) { 1137 rendezvous->Unref(); 1138 done(status); 1139 }, 1140 std::move(done), std::placeholders::_1); 1141 } 1142 1143 Item* item = nullptr; 1144 Status s = GetOrCreateItem(local_handle, &item); 1145 if (!s.ok()) { 1146 done(s); 1147 return; 1148 } 1149 if (run_opts.runner == nullptr) { 1150 run_opts.runner = &default_runner_; 1151 } 1152 DCHECK(run_opts.runner != nullptr); 1153 1154 Executor::Args exec_args; 1155 ExecutorArgsFromOptions(run_opts, frame, &exec_args); 1156 item->exec->RunAsync(exec_args, std::move(done)); 1157 } 1158 1159 bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) { 1160 const OpDef* op_def; 1161 const Status s = base_lib_def_->LookUpOpDef(func, &op_def); 1162 return s.ok() && op_def->is_stateful(); 1163 } 1164 1165 string FunctionLibraryRuntimeImpl::DebugString(Handle handle) { 1166 Item* item = nullptr; 1167 LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle); 1168 Status s = GetOrCreateItem(local_handle, &item); 1169 if (s.ok()) { 1170 return tensorflow::DebugString(item->graph); 1171 } else { 1172 return s.ToString(); 1173 } 1174 } 1175 1176 Status FunctionLibraryRuntimeImpl::Clone( 1177 std::unique_ptr<FunctionLibraryDefinition>* out_lib_def, 1178 std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr, 1179 FunctionLibraryRuntime** out_flr) { 1180 TF_RETURN_IF_ERROR( 1181 parent_->Clone(env_, graph_def_version_, optimizer_.options(), 1182 custom_kernel_creator_, out_lib_def, out_pflr)); 1183 *out_flr = (*out_pflr)->GetFLR(device_->name()); 1184 if (out_flr != nullptr) { 1185 return Status::OK(); 1186 } else { 1187 return errors::Internal("Cloning FunctionLibraryRuntime failed."); 1188 } 1189 } 1190 1191 namespace { 1192 1193 struct CustomCreatorSingleton { 1194 mutex mu; 1195 CustomKernelCreator custom_creator = nullptr; 1196 1197 void Set(CustomKernelCreator cb) { 1198 mutex_lock l(mu); 1199 custom_creator = std::move(cb); 1200 } 1201 1202 CustomKernelCreator Get() { 1203 
mutex_lock l(mu); 1204 return custom_creator; 1205 } 1206 }; 1207 1208 CustomCreatorSingleton* GetCustomCreatorSingleton() { 1209 static CustomCreatorSingleton* ccs = new CustomCreatorSingleton; 1210 return ccs; 1211 } 1212 1213 } // namespace 1214 1215 void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb) { 1216 GetCustomCreatorSingleton()->Set(std::move(cb)); 1217 } 1218 1219 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime( 1220 const DeviceMgr* device_mgr, Env* env, Device* device, 1221 int graph_def_version, const FunctionLibraryDefinition* lib_def, 1222 thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options, 1223 CustomKernelCreator custom_kernel_creator, 1224 ProcessFunctionLibraryRuntime* parent) { 1225 return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl( 1226 device_mgr, env, device, graph_def_version, lib_def, thread_pool, 1227 optimizer_options, std::move(custom_kernel_creator), parent)); 1228 } 1229 1230 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime( 1231 const DeviceMgr* device_mgr, Env* env, Device* device, 1232 int graph_def_version, const FunctionLibraryDefinition* lib_def, 1233 thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options, 1234 ProcessFunctionLibraryRuntime* parent) { 1235 return NewFunctionLibraryRuntime(device_mgr, env, device, graph_def_version, 1236 lib_def, thread_pool, optimizer_options, 1237 GetCustomCreatorSingleton()->Get(), parent); 1238 } 1239 1240 bool RemoveDeadNodes(Graph* g) { 1241 VLOG(2) << "Removing dead nodes"; 1242 std::unordered_set<const Node*> nodes; 1243 for (auto n : g->nodes()) { 1244 if (n->IsSource() || n->IsSink() || n->IsControlFlow() || 1245 n->op_def().is_stateful()) { 1246 nodes.insert(n); 1247 } 1248 } 1249 return PruneForReverseReachability(g, std::move(nodes)); 1250 } 1251 1252 namespace { 1253 // If 'edges' contains only 1 non-control edge, returns it. Otherwise, 1254 // returns a nullptr. 
const Edge* GetTheOnlyDataEdge(const EdgeSet& edges) {
  const Edge* ret = nullptr;
  for (const Edge* e : edges) {
    if (e->IsControlEdge() || ret) {
      // Don't touch it if there is a control edge.
      return nullptr;
    }
    if (IsRefType(e->src()->output_type(e->src_output()))) {
      // Don't touch it if the identity node is effectively de-reffing
      // a ref.
      return nullptr;
    }
    if (IsRecv(e->src()) || IsSwitch(e->src())) {
      // Don't touch it if the identity is introduced for control flow.
      // Recv disables all its successors if it receives a dead signal.
      // When Recv has an outgoing control edge, the current executor
      // would not disable the destination. The current solution (see
      // graph_partition.cc) is to add an identity after Recv and change
      // the control edge to be from this identity node. So the identity
      // can't be removed.
      return nullptr;
    }
    ret = e;
  }
  return ret;
}
}  // end namespace

// Bypasses and removes Identity nodes that have exactly one (safe) data input,
// rewiring their outgoing edges to the input's source. Returns true iff any
// node was removed.
bool RemoveIdentityNodes(Graph* g) {
  VLOG(2) << "Removing identity nodes";
  bool removed_any = false;
  gtl::InlinedVector<Node*, 8> matches;
  for (Node* n : g->nodes()) {
    if (!n->IsIdentity()) continue;
    if (!GetTheOnlyDataEdge(n->in_edges())) continue;

    // Some identity nodes are used as sink nodes to give names to output
    // tensors. These nodes are not going to be executed unless they are in the
    // fetch set. But if they are in the fetch set we don't want to remove them.
    if (n->out_edges().empty()) continue;

    matches.push_back(n);
  }
  if (!matches.empty()) {
    for (Node* n : matches) {
      const Edge* in = GetTheOnlyDataEdge(n->in_edges());
      for (const Edge* out : n->out_edges()) {
        if (out->IsControlEdge()) {
          g->AddControlEdge(in->src(), out->dst());
        } else {
          g->AddEdge(in->src(), in->src_output(), out->dst(), out->dst_input());
        }
      }
      VLOG(2) << "Remove Identity: " << n->DebugString();
      g->RemoveNode(n);
      removed_any = true;
    }
  }
  return removed_any;
}

// Replaces _ListToArray/_ArrayToList converter nodes with per-element Identity
// nodes (plus NoOp nodes to preserve control dependencies). Returns true iff
// any converter was removed.
bool RemoveListArrayConverter(Graph* g) {
  VLOG(2) << "Removing list array converter";
  gtl::InlinedVector<Node*, 8> matches;
  for (Node* n : g->nodes()) {
    if ((n->type_string() == "_ListToArray") ||
        (n->type_string() == "_ArrayToList")) {
      matches.push_back(n);
    }
  }
  bool removed_any = false;
  if (!matches.empty()) {
    for (Node* n : matches) {
      if (n->num_inputs() != n->num_outputs()) {
        continue;  // Not expected. Skip.
      }
      // One Identity per converter input, indexed by input slot.
      gtl::InlinedVector<Node*, 8> identity_nodes(n->num_inputs(), nullptr);

      const auto no_op = [&](StringPiece name) {
        return AddNoOp(absl::StrCat(n->name(), "/", name), g);
      };

      const auto identity = [&](StringPiece name, Endpoint input) {
        return AddIdentity(absl::StrCat(n->name(), "/", name), g, input);
      };

      // Process input edges first.
      Node* input_control_node = nullptr;
      for (const Edge* e : n->in_edges()) {
        if (e->IsControlEdge()) {
          if (input_control_node == nullptr) {
            // If node "n" has any control dependencies, adds a no-op
            // node (input_control_node) which the additional Identity
            // nodes depends on and the input_control_node depends on
            // the node "n"s control dependencies.
            input_control_node = no_op("input_control_node");
          }
          g->AddControlEdge(e->src(), input_control_node);
        } else {
          const int index = e->dst_input();
          Node** id_node = &identity_nodes[index];
          if (*id_node != nullptr) {
            LOG(ERROR)
                << "RemoveListArrayConverter unexpected duplicated input: "
                << e->dst_input();
            return removed_any;
          }
          *id_node = identity("input", {e->src(), e->src_output()});
        }
      }

      // If node "n" has any control dependencies, the added identity
      // nodes should have control dependencies on input_control_node.
      if (input_control_node != nullptr) {
        for (Node* id : identity_nodes) {
          g->AddControlEdge(input_control_node, id);
        }
      }

      Node* output_control_node = nullptr;
      for (const Edge* e : n->out_edges()) {
        if (e->IsControlEdge()) {
          if (output_control_node == nullptr) {
            // If node "n" is control-depended upon by other nodes,
            // adds a no-op node (output_control_node) which those
            // nodes will depend on and output_control_node depends on
            // all Identity nodes.
            output_control_node = no_op("output_control_node");
          }
          g->AddControlEdge(output_control_node, e->dst());
        } else {
          Node* id_node = identity_nodes[e->src_output()];
          if (id_node == nullptr) {
            LOG(ERROR) << "RemoveListArrayConverter unexpected missing input: "
                       << e->src_output();
            return removed_any;
          }
          // NOTE(review): this CHECK is redundant — id_node was just verified
          // non-null above.
          CHECK(id_node);
          g->AddEdge(id_node, 0, e->dst(), e->dst_input());
        }
      }

      // If any nodes have control dependencies on node "n", those
      // nodes should have control dependencies on
      // output_control_node.
      if (output_control_node != nullptr) {
        for (Node* id : identity_nodes) {
          g->AddControlEdge(id, output_control_node);
        }
      }

      g->RemoveNode(n);
      removed_any = true;
    }
  }
  return removed_any;
}

// Instantiates the function invoked by `call_def` in `flr`. Handles both
// PartitionedCall-style nodes (function name in the "f" attr) and direct
// function-call nodes (function name is the op name itself).
Status InstantiateFunctionCall(const NodeDef& call_def,
                               FunctionLibraryRuntime& flr,
                               FunctionLibraryRuntime::Handle* handle) {
  const string* func_name;
  AttrSlice attrs;

  NameAttrList func;
  if (call_def.op() == "PartitionedCall" ||
      call_def.op() == "StatefulPartitionedCall") {
    TF_RETURN_IF_ERROR(GetNodeAttr(call_def, "f", &func));
    func_name = &func.name();
    attrs = AttrSlice(&func.attr());
  } else {
    func_name = &call_def.op();
    attrs = AttrSlice(call_def);
  }

  return flr.Instantiate(*func_name, attrs, handle);
}

namespace {

// Returns an error if `fbody`'s FunctionDef carries the `_noinline` attribute.
Status ValidateNoInline(const FunctionBody* fbody) {
  const auto attr = AttrSlice(&fbody->fdef.attr());
  bool noinline = false;
  if (GetNodeAttr(attr, kNoInlineAttr, &noinline).ok() && noinline) {
    return errors::InvalidArgument(
        "Can't inline function marked with '_noinline'");
  }
  return Status::OK();
}

using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;

}  // namespace

// Human-readable rendering of the inlining options, used in VLOG output.
string InlineFunctionBodyOptions::DebugString() const {
  return absl::StrCat("ignore_noinline=", ignore_noinline ? "true" : "false",
                      ", override_device=", override_device ? "true" : "false",
                      ", output_control_src=",
                      output_control_src == OutputControlSrc::kDataOutputs
                          ? "DataOutputs"
                          : "ControlOutputs");
}

// Checks that `node`'s inputs/outputs match `fbody`'s signature (counts and
// dtypes) and that inlining is permitted by `options`.
Status ValidateInlining(const Node* node, const FunctionBody* fbody,
                        const InlineFunctionBodyOptions& options) {
  // TODO(ezhulenev): Currently common_runtime function inlining can't guarantee
  // that all side-effectful ops will be executed after inlining.
  // See Grappler
  // function_optimizer for details. Unify all function inlining mechanism.
  // Do not inline if `!fbody->control_ret_nodes.empty()`.

  const auto num_node_inputs = static_cast<size_t>(node->num_inputs());
  const auto num_node_outputs = static_cast<size_t>(node->num_outputs());

  if (num_node_inputs != fbody->arg_types.size() ||
      num_node_inputs != fbody->arg_nodes.size()) {
    return errors::InvalidArgument(
        "Node inputs do not match function arguments: inputs=", num_node_inputs,
        " arg_types=", fbody->arg_types.size(),
        " arg_nodes=", fbody->arg_nodes.size());
  }

  if (num_node_outputs != fbody->ret_types.size() ||
      num_node_outputs != fbody->ret_nodes.size()) {
    return errors::InvalidArgument(
        "Node outputs do not match function returns: outputs=",
        num_node_outputs, " ret_types=", fbody->ret_types.size(),
        " ret_nodes=", fbody->ret_nodes.size());
  }

  // Per-slot dtype checks.
  for (int i = 0; i < node->num_inputs(); ++i) {
    if (node->input_type(i) != fbody->arg_types[i]) {
      return errors::InvalidArgument(
          "Node input type doesn't match function argument type: ",
          node->input_type(i), " != ", fbody->arg_types[i], " @ index=", i);
    }
  }
  for (int i = 0; i < node->num_outputs(); ++i) {
    if (node->output_type(i) != fbody->ret_types[i]) {
      return errors::InvalidArgument(
          "Node output type doesn't match function return type: ",
          node->output_type(i), " != ", fbody->ret_types[i], " @ index=", i);
    }
  }

  if (!options.ignore_noinline) {
    TF_RETURN_IF_ERROR(ValidateNoInline(fbody));
  }

  return Status::OK();
}

// Function inlining must preserve function execution semantics with regards to
// side-effects visibility. Tensorflow in Eager mode has an automatic control
// dependencies tracking mechanism, which enforces well-defined execution order
// of all side-effects. Any other frontend (e.g. Swift) must produce graphs
// following the same rules, to ensure that function inlining works correctly.
//
// IMPORTANT: Currently we do not have a true notion of "side-effectful" node,
// we assume that all stateful nodes might have side-effects, though it's not
// true in practice, e.g. `ReadVariableOp` doesn't have an observable
// side-effect.
//
// Automatic control dependency rules in Tensorflow 2.0 (python in eager mode):
//
// 1) When a function has a resource (DT_RESOURCE data type) input argument it
//    "captures" the mutable resource. This is implemented by automatically
//    adding a incoming control edge from the previous side-effectful op
//    touching that resource, and an outgoing control edge to the next
//    side-effectful op using the same resource. This serializes the mutations
//    of the resource to make graph execution deterministic.
//
// 2) All stateful ops inside a function body are guaranteed to execute in
//    program order, this is achieved by adding control edges between stateful
//    ops at graph construction time. Stateful ops (or ops that must execute)
//    should be in the function control return set. Having a data edge to the
//    regular function output might be not enough, because after function
//    inlining it might happen that data output is unused.
//
// 3) Furthermore, all ops accepting the same resource as an input are
//    guaranteed to run in program order. This is also done by adding control
//    edges at graph construction time. The last op touching the resource
//    must be in a control return set, which will guarantee that all side
//    effects to the resource will happen before function completion.
//
// Function inlining must preserve side-effect visibility:
//
// 1) All side-effects to the captured resources, that happened before function
//    call must be visible to the function body nodes using that resources.
//
// 2) All side-effects to the captured resources, that happened inside function
//    body, must be visible to every op/function using that resource after the
//    function call completed.
//
// To guarantee that these properties are preserved after inlining we:
//
// 1) Create "input_control_node" NoOp. Function call node incoming control
//    edges will be forwarded *to* this node. Function inputs (Identity nodes)
//    will have a control edge *from* this node. If function body has nodes
//    without inputs, they will have a control edge *from* this node.
//
// 2) Create "output_control_node" NoOp. All nodes that have incoming control
//    edge *from* the function call node, will be forwarded to this node.
//
// We have two options for choosing which nodes will have a control edge *to*
// the "output control node":
//   a) control returns (`control_ret` field in FunctionDef)
//   b) data returns (`ret` field in FunctionDef)
//
// We do a) for multi-device function calls in Tensorflow v2 and b)
// for the rest for compatibility with Tensorflow v1.
//
// Following the automatic control dependencies tracking rules, a node that
// has an incoming control edge from the function call node is dependent on
// the side-effects happening inside the function body. The output control
// node will guarantee side-effects execution order.
//
// If function call node doesn't have an outgoing control edge, it means that
// no one is interested in observing side-effects that might have happened.
//
// Function inlining might leave the graph in partially-placed state. Function
// inlining caller must call Placer to guarantee that all nodes are placed.
//
// Function inlining with `options.override_device=true` will leave graph in
// fully placed state, by overriding all inlined nodes devices with the caller
// node device, but it will make functions always single-device. These functions
// after inlining will not be able to handle resources on multiple devices. This
// is currently acceptable for XLA use cases (XLA cluster is always executed on
// a single device).
//
// TODO(ezhulenev): Documentation above is ahead of implementation below.
Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
                          Node* caller, const FunctionBody* fbody,
                          const InlineFunctionBodyOptions& options) {
  VLOG(3) << "Inline function call: " << SummarizeNode(*caller) << " ["
          << options.DebugString() << "]";
  VLOG(4) << "Inlined function definition: " << DebugString(fbody->fdef);

  Status validation = ValidateInlining(caller, fbody, options);
  if (!validation.ok()) {
    LOG(WARNING) << "Inlining mismatch: " << SummarizeNode(*caller) << " vs. "
                 << DebugString(fbody->graph);
    return errors::Internal("Inlining mismatch: ", validation.error_message());
  }

  // ------------------------------------------------------------------------ //
  // Helper functions to create `NoOp` and `Identity` nodes for auxiliary
  // control nodes and inlined function inputs and outputs.

  // Add a NoOp node for function control inputs/outputs.
  const auto no_op = [&](StringPiece name) {
    Node* node = AddNoOp(absl::StrCat(caller->name(), "/", name), g);
    node->set_requested_device(caller->def().device());
    return node;
  };

  // Add an Identity node for function data inputs/outputs.
  const auto identity = [&](StringPiece name, Endpoint input) {
    return AddIdentity(absl::StrCat(caller->name(), "/", name), g, input);
  };

  // ------------------------------------------------------------------------ //
  // Input edges. For data edges coming into "caller", we first compute the
  // <src>:<src_output> for the i-th input in "inputs".
  // If "caller" has any input control dependencies, we add a NoOp
  // node "input_control_node", which depends on "caller"'s control inputs.
  std::vector<Endpoint> inputs(caller->num_inputs());
  Node* input_control_node = nullptr;
  for (const Edge* e : caller->in_edges()) {
    if (e->IsControlEdge()) {
      if (input_control_node == nullptr) {
        input_control_node = no_op("input_control_node");
      }
      g->AddControlEdge(e->src(), input_control_node);
    } else {
      inputs[e->dst_input()] = {e->src(), e->src_output()};
    }
  }

  // ------------------------------------------------------------------------ //
  // Duplicate fbody->graph into 'g'. First, we copy the nodes of
  // fbody->graph into 'g' except the source and sink nodes. We copy
  // edges among nodes in 'fbody->graph'.
  //
  // If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we
  // remember 'y' in node_map[x->id()].
  std::vector<Node*> node_map(fbody->graph->num_node_ids());
  for (Node* n : fbody->graph->op_nodes()) {
    NodeDef ndef = n->def();
    // Prefix the clone's name with the caller's name to keep names unique.
    ndef.set_name(strings::StrCat(caller->name(), "/", ndef.name()));
    if (options.override_device || ndef.device().empty()) {
      ndef.set_device(caller->def().device());
    }
    for (auto& attr : *ndef.mutable_attr()) {
      if (attr.first == "_class") {
        // Colocation class names must be rewritten with the same prefix.
        attr.second.set_s(
            strings::StrCat(caller->name(), "/", attr.second.s()));
      }
    }
    Status added_node;
    Node* clone = g->AddNode(ndef, &added_node);
    if (options.override_device && !caller->assigned_device_name().empty()) {
      clone->set_assigned_device_name(caller->assigned_device_name());
    }
    TF_CHECK_OK(added_node);
    node_map[n->id()] = clone;

    // If there is an input control node, and one of:
    // a) the node has no data or control inputs, or
    // b) the node is a function call or SymbolicGradient,
    // then add a control edge from the input control node to the clone.
    //
    // We must not execute any nodes if the original function call would not
    // have executed. This is especially critical when the function call is
    // inside a control-flow construct like tf.cond(). Case (a) ensures that
    // such nodes do not run.
    //
    // The purpose of case (b) is to ensure that instances of case (a) created
    // by further inlining steps also receive the control dependency.
    //
    // TODO(ezhulenev): If caller has no control inputs, should we add a control
    // edge from one of the inputs to ensure that function body node will
    // execute in correct frame?
    if (input_control_node) {
      bool has_inputs = absl::c_any_of(
          n->in_edges(), [](const Edge* e) { return !e->src()->IsSource(); });
      if (!has_inputs || flib_def.Find(clone->type_string()) != nullptr ||
          clone->type_string() == kGradientOp) {
        g->AddControlEdge(input_control_node, clone);
      }
    }
  }
  // Copy all intra-body edges (skipping SOURCE/SINK endpoints).
  for (const Edge* e : fbody->graph->edges()) {
    if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() ||
        e->dst()->IsSink()) {
      continue;
    }
    Node* src_copy = node_map[e->src()->id()];
    Node* dst_copy = node_map[e->dst()->id()];
    g->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
  }

  // ------------------------------------------------------------------------ //
  // Connect input edges.
  //
  // We create one Identity node for each input. Then, we connect inputs[i] to
  // the i-th identity node added. The nodes that previously connected
  // to the j-th output of i-th arg node are reconnected to the i-th
  // identity node.
  //
  // The added identity nodes depend on "input_control_node".
  for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) {
    Node* arg = node_map[fbody->arg_nodes[i]->id()];
    Node* n = identity("input", inputs[i]);
    if (input_control_node) {
      g->AddControlEdge(input_control_node, n);
    }
    for (const Edge* e : arg->out_edges()) {
      if (e->IsControlEdge()) {
        g->AddControlEdge(n, e->dst());
      } else {
        g->AddEdge(n, 0, e->dst(), e->dst_input());
      }
    }
    node_map[fbody->arg_nodes[i]->id()] = n;
    g->RemoveNode(arg);  // 'arg' is disconnected.
  }

  // ------------------------------------------------------------------------ //
  // Connect output edges.
  //
  // For i-th return node in fbody->graph, we add in "g" an identity node
  // (outputs[i-th]). We then reconnect every incoming edge into the i-th return
  // node to the added identity node.
  //
  // For every data edge coming out of "callee"s i-th output, we reconnect it to
  // the i-th identity added above.
  //
  // If "callee" is control-depended upon by any other nodes, we add a NoOp node
  // "output_control_node". "output_control_node" depends on all identity nodes
  // added above or on all control return nodes (controlled by
  // `options.output_control_src` value). And nodes previously depend on
  // "callee" is changed to depend on "output_control_node".
  std::vector<Node*> outputs(caller->num_outputs());
  for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) {
    Node* ret = node_map[fbody->ret_nodes[i]->id()];
    Endpoint data;  // Data input for the ret node.
    for (const Edge* e : ret->in_edges()) {
      if (!e->IsControlEdge()) {
        data = {e->src(), e->src_output()};
        break;
      }
    }
    CHECK(data.node != nullptr);
    Node* n = identity("output", data);
    outputs[i] = n;
    // Preserve control dependencies that fed the return node.
    for (const Edge* e : ret->in_edges()) {
      if (e->IsControlEdge()) {
        g->AddControlEdge(e->src(), n);
      }
    }
    g->RemoveNode(ret);  // 'ret' is disconnected.
1756 } 1757 Node* output_control_node = nullptr; 1758 for (const Edge* e : caller->out_edges()) { 1759 if (e->IsControlEdge()) { 1760 if (output_control_node == nullptr) { 1761 output_control_node = no_op("output_control_node"); 1762 if (options.output_control_src == 1763 InlineFunctionBodyOptions::OutputControlSource::kDataOutputs) { 1764 for (Node* n : outputs) { 1765 g->AddControlEdge(n, output_control_node); 1766 } 1767 } else { 1768 for (Node* fbody_node : fbody->control_ret_nodes) { 1769 Node* n = node_map[fbody_node->id()]; 1770 g->AddControlEdge(n, output_control_node); 1771 } 1772 } 1773 } 1774 g->AddControlEdge(output_control_node, e->dst()); 1775 } else { 1776 g->AddEdge(outputs[e->src_output()], 0, e->dst(), e->dst_input()); 1777 } 1778 } 1779 g->RemoveNode(caller); // 'caller' is replaced with inlined nodes. 1780 1781 return Status::OK(); 1782 } 1783 1784 bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, 1785 const Node& node) { 1786 return node.IsPartitionedCall() || 1787 node.type_string() == FunctionLibraryDefinition::kGradientOp || 1788 lib_def.Find(node.def().op()) != nullptr; 1789 } 1790 1791 bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph, 1792 const ExpandInlineFunctionsOptions& options) { 1793 std::vector<std::pair<Node*, const FunctionBody*>> candidates; 1794 1795 const FunctionLibraryDefinition* fld = lib->GetFunctionLibraryDefinition(); 1796 1797 for (Node* node : graph->nodes()) { 1798 // Skip nodes that are not function calls or SymbolicGradient calls. 1799 if (!IsFunctionCall(*lib->GetFunctionLibraryDefinition(), *node)) { 1800 continue; 1801 } 1802 // Skip function calls that marked noinline. 
1803 bool noinline; 1804 if (fld->GetAttr(*node, kNoInlineAttr, &noinline).ok() && noinline) { 1805 VLOG(3) << "noinline: " << SummarizeNode(*node); 1806 continue; 1807 } 1808 FunctionLibraryRuntime::Handle handle; 1809 Status s = InstantiateFunctionCall(node->def(), *lib, &handle); 1810 if (!s.ok()) { 1811 LOG(ERROR) << "Failed to instantiate a function: " << s.error_message(); 1812 continue; 1813 } 1814 const FunctionBody* fbody = lib->GetFunctionBody(handle); 1815 CHECK_NOTNULL(fbody); 1816 candidates.emplace_back(node, fbody); 1817 } 1818 1819 bool inlined_any = false; 1820 for (const auto& p : candidates) { 1821 Status inlined = InlineFunctionBody(*fld, graph, p.first, p.second, 1822 p.first->IsPartitionedCall() 1823 ? options.multi_device_options 1824 : options.native_options); 1825 if (inlined.ok()) { 1826 inlined_any = true; 1827 } else { 1828 VLOG(1) << "Failed to inline function call: node=" << p.first->name() 1829 << " error=" << inlined.error_message(); 1830 } 1831 } 1832 1833 // TODO(ezhulenev): Release handles for inlined function calls. 1834 1835 return inlined_any; 1836 } 1837 1838 string NewName(const Node* n, bool pretty) { 1839 if (pretty) { 1840 return strings::StrCat(n->type_string(), n->id()); 1841 } else { 1842 return strings::StrCat("n", n->id()); 1843 } 1844 } 1845 1846 // TODO(zhifengc): Maybe this should be the default Graph::AsGraphDef. 1847 // and stash the original NodeDef name as an attr for documentation 1848 // purpose. 1849 void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) { 1850 // We visit nodes in forward topological sort order, which is a 1851 // possible execution order of the graph. 
1852 gtl::InlinedVector<const Edge*, 4> inputs; 1853 gdef->Clear(); 1854 gdef->mutable_versions()->CopyFrom(g->versions()); 1855 1856 std::vector<Node*> start_nodes; 1857 for (Node* n : g->nodes()) { 1858 if (n->out_edges().empty()) { 1859 start_nodes.push_back(n); 1860 } 1861 } 1862 1863 ReverseDFSFrom(*g, start_nodes, nullptr, [gdef, pretty, &inputs](Node* n) { 1864 if (!n->IsOp()) return; 1865 NodeDef* ndef = gdef->add_node(); 1866 ndef->set_name(NewName(n, pretty)); 1867 ndef->set_op(n->type_string()); 1868 for (const auto& attr : n->attrs()) { 1869 (*ndef->mutable_attr())[attr.first] = attr.second; 1870 } 1871 1872 if (!n->assigned_device_name().empty()) { 1873 ndef->set_device(n->assigned_device_name()); 1874 } else { 1875 ndef->set_device(n->requested_device()); 1876 } 1877 1878 inputs.clear(); 1879 inputs.resize(n->num_inputs()); 1880 for (const Edge* e : n->in_edges()) { 1881 if (e->IsControlEdge()) { 1882 inputs.push_back(e); 1883 } else { 1884 if (inputs[e->dst_input()] == nullptr) { 1885 inputs[e->dst_input()] = e; 1886 } else { 1887 LOG(WARNING) << "Malformed graph node. multiple input edges: " 1888 << n->DebugString(); 1889 } 1890 } 1891 } 1892 // node->name() is merely NodeDef::name, which are not guaranteed 1893 // to be unique and stable after optimization rewrites. Therefore, 1894 // we use "n<node id>" instead. 
1895 for (const Edge* e : inputs) { 1896 if (e == nullptr) { 1897 ndef->add_input("unknown"); 1898 continue; 1899 } 1900 const string srcname = NewName(e->src(), pretty); 1901 if (!e->src()->IsOp()) { 1902 } else if (e->IsControlEdge()) { 1903 ndef->add_input(strings::StrCat("^", srcname)); 1904 } else if (e->src_output() == 0) { 1905 ndef->add_input(srcname); 1906 } else { 1907 ndef->add_input(strings::StrCat(srcname, ":", e->src_output())); 1908 } 1909 } 1910 }); 1911 } 1912 1913 string DebugString(const Graph* g) { 1914 GraphDef gdef; 1915 ToGraphDef(g, &gdef); 1916 return DebugString(gdef); 1917 } 1918 1919 FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t, 1920 DataTypeSlice ret_t, Graph* g) 1921 : fdef(f), 1922 graph(g), 1923 arg_types(arg_t.begin(), arg_t.end()), 1924 ret_types(ret_t.begin(), ret_t.end()) { 1925 // 1. Find regular Arg/Ret nodes. 1926 this->arg_nodes.resize(arg_types.size()); 1927 this->ret_nodes.resize(ret_types.size()); 1928 for (Node* n : this->graph->op_nodes()) { 1929 gtl::InlinedVector<Node*, 4>* node_vec; 1930 if (n->type_string() == kRetOp || n->type_string() == kDeviceRetOp) { 1931 node_vec = &this->ret_nodes; 1932 } else if (n->type_string() == kArgOp || n->type_string() == kDeviceArgOp) { 1933 node_vec = &this->arg_nodes; 1934 } else { 1935 continue; 1936 } 1937 int index; 1938 TF_CHECK_OK(GetNodeAttr(n->attrs(), "index", &index)); 1939 CHECK_LE(0, index); 1940 CHECK_LT(index, node_vec->size()); 1941 (*node_vec)[index] = n; 1942 } 1943 // 2. Find ControlRet nodes that must be always executed. 
1944 std::unordered_set<StringPiece, StringPieceHasher> control_ret_node_names; 1945 for (const auto& control_ret : fdef.control_ret()) { 1946 control_ret_node_names.insert(control_ret.second); 1947 } 1948 this->control_ret_nodes.reserve(control_ret_node_names.size()); 1949 for (Node* n : this->graph->op_nodes()) { 1950 if (control_ret_node_names.count(n->name()) > 0) { 1951 this->control_ret_nodes.push_back(n); 1952 } 1953 } 1954 } 1955 1956 FunctionBody::~FunctionBody() { delete this->graph; } 1957 1958 class SymbolicGradientHelper { 1959 public: 1960 explicit SymbolicGradientHelper(const FunctionBody& f) : fbody_(&f) {} 1961 1962 ~SymbolicGradientHelper() { delete gbody_; } 1963 1964 FunctionBody* Compute(); 1965 1966 private: 1967 const FunctionBody* fbody_; 1968 FunctionBody* gbody_ = nullptr; 1969 1970 // Makes a copy of fbody_ in gbody_. 1971 void Copy(); 1972 1973 TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientHelper); 1974 }; 1975 1976 void SymbolicGradientHelper::Copy() { 1977 const Graph& src = *(fbody_->graph); 1978 gbody_->graph = new Graph(src.op_registry()); 1979 Graph* dst = gbody_->graph; 1980 1981 std::vector<Node*> node_map(src.num_node_ids()); 1982 1983 // Copy the nodes. 1984 node_map[src.source_node()->id()] = dst->source_node(); 1985 node_map[src.sink_node()->id()] = dst->sink_node(); 1986 for (Node* n : src.op_nodes()) { 1987 node_map[n->id()] = dst->CopyNode(n); 1988 } 1989 1990 // Copy the edges. 1991 for (const Edge* e : src.edges()) { 1992 Node* src_copy = node_map[e->src()->id()]; 1993 Node* dst_copy = node_map[e->dst()->id()]; 1994 dst->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input()); 1995 } 1996 1997 // Save inputs in copied graph. 1998 CHECK_EQ(fbody_->arg_types.size(), fbody_->arg_nodes.size()); 1999 gbody_->arg_types = fbody_->arg_types; 2000 for (std::size_t i = 0; i < fbody_->arg_nodes.size(); ++i) { 2001 gbody_->arg_nodes.push_back(node_map[fbody_->arg_nodes[i]->id()]); 2002 } 2003 2004 // Save outputs in copied graph. 
2005 CHECK_EQ(fbody_->ret_types.size(), fbody_->ret_nodes.size()); 2006 gbody_->ret_types = fbody_->ret_types; 2007 for (std::size_t i = 0; i < fbody_->ret_nodes.size(); ++i) { 2008 gbody_->ret_nodes.push_back(node_map[fbody_->ret_nodes[i]->id()]); 2009 } 2010 } 2011 2012 FunctionBody* SymbolicGradientHelper::Compute() { 2013 CHECK(gbody_ == nullptr); 2014 gbody_ = new FunctionBody; 2015 2016 // Copy fbody_ into gbody_. 2017 Copy(); 2018 2019 Graph* g = gbody_->graph; 2020 2021 const int num_y = static_cast<int>(gbody_->ret_nodes.size()); 2022 2023 // Populate 'y_node_outputs_' with node function body outputs. 2024 // Populate 'y_grad_nodes' with initial gradient nodes for each return node 2025 // of the original function body (these will be 'arg' nodes in the function 2026 // gradient body). 2027 std::vector<NodeOut> y_node_outputs; 2028 y_node_outputs.reserve(num_y); 2029 std::vector<NodeOut> y_grad_node_outputs; 2030 y_grad_node_outputs.reserve(num_y); 2031 for (int i = 0; i < num_y; ++i) { 2032 Node* y = gbody_->ret_nodes[i]; 2033 y_node_outputs.push_back({y, 0}); 2034 DCHECK_EQ(y->type_string(), kRetOp); 2035 const DataType dtype = y->input_type(0); 2036 const int index = static_cast<int>(gbody_->arg_nodes.size()); 2037 Node* dy = AddArg(g, dtype, index); 2038 gbody_->arg_types.push_back(dtype); 2039 gbody_->arg_nodes.push_back(dy); 2040 y_grad_node_outputs.push_back({dy, 0}); 2041 } 2042 2043 // Populate 'x_nodes' with function args (excluding 'y_grad_node_outputs'). 2044 const size_t num_x = fbody_->arg_nodes.size(); 2045 std::vector<NodeOut> x_node_outputs; 2046 x_node_outputs.reserve(num_x); 2047 for (size_t i = 0; i < fbody_->arg_nodes.size(); ++i) { 2048 x_node_outputs.push_back({gbody_->arg_nodes[i], 0}); 2049 } 2050 2051 // Call AddSymbolicGradients which will add nodes to graph 'g' that 2052 // compute the function gradient (adding an entry in 'x_grad_node_outputs' 2053 // for each node in 'x_node_outputs'). 
2054 std::vector<NodeOut> x_grad_node_outputs; 2055 TF_CHECK_OK(AddSymbolicGradients(y_node_outputs, x_node_outputs, 2056 y_grad_node_outputs, &x_grad_node_outputs, 2057 g)); 2058 2059 // Remove the old return nodes from the function body. 2060 for (Node* n : gbody_->ret_nodes) { 2061 g->RemoveNode(n); 2062 } 2063 gbody_->ret_types = fbody_->arg_types; 2064 // TODO(apassos): use the right dtype for gradients of resource variables 2065 for (int i = 0; i < gbody_->ret_types.size(); ++i) { 2066 if (gbody_->ret_types[i] == DT_RESOURCE) { 2067 gbody_->ret_types[i] = DT_FLOAT; 2068 } 2069 } 2070 gbody_->ret_nodes.clear(); 2071 // Add new return nodes to the function gradient body for each node 2072 // in 'x_grad_nodes'. 2073 const int arg_types_size = static_cast<int>(fbody_->arg_types.size()); 2074 for (int i = 0; i < arg_types_size; ++i) { 2075 Endpoint grad = {x_grad_node_outputs[i].node, x_grad_node_outputs[i].index}; 2076 Node* ret = AddRet(g, grad, i); 2077 gbody_->ret_nodes.push_back(ret); 2078 } 2079 2080 auto ret = gbody_; 2081 gbody_ = nullptr; 2082 return ret; 2083 } 2084 2085 FunctionBody* SymbolicGradient(const FunctionBody& f) { 2086 return SymbolicGradientHelper(f).Compute(); 2087 } 2088 2089 Status FunctionDefToBodyHelper( 2090 const FunctionDef& fdef, const AttrSlice& attrs, 2091 const FunctionLibraryDefinition* const lib_def, 2092 const std::function<Status(const string&, const OpDef**)>& get_func_sig, 2093 FunctionBody** fbody) { 2094 // Instantiates the function template into a graph def. 
2095 InstantiationResult result; 2096 TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result)); 2097 2098 std::unique_ptr<Graph> graph(new Graph(lib_def)); 2099 GraphConstructorOptions opts; 2100 opts.allow_internal_ops = true; 2101 opts.expect_device_spec = false; 2102 TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get())); 2103 2104 // Call BuildControlFlowInfo to validate that this function body has 2105 // well-formed control flow. 2106 std::vector<ControlFlowInfo> dummy; 2107 TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy)); 2108 2109 *fbody = new FunctionBody(fdef, result.arg_types, result.ret_types, 2110 graph.release()); 2111 return Status::OK(); 2112 } 2113 2114 } // end namespace tensorflow 2115