/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/graph_mgr.h"

#include <vector>

#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {

GraphMgr::GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr)
    : worker_env_(worker_env), device_mgr_(device_mgr), table_(5) {
  // The default value of sync_on_finish will be flipped soon and this
  // environment variable will be removed as well.
  Status status =
      ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
  if (!status.ok()) {
    LOG(ERROR) << status.error_message();
  }
}

GraphMgr::~GraphMgr() {
  for (auto p : table_) p.second->Unref();
}

GraphMgr::Item::~Item() {
  for (const auto& unit : this->units) {
    CHECK_NOTNULL(unit.device);
    if (!graph_mgr->skip_cost_models_) {
      graph_mgr->cost_model_manager_.RemoveCostModelForGraph(unit.graph);
    }
    delete unit.root;
    unit.device->op_segment()->RemoveHold(this->session);
  }
}

// NOTE: node->device_name() is not set by GraphConstructor. We
// expect that each NodeDef in the GraphDef given to workers fully
// specifies a device name.
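// For illustration, a fully specified device name has the form:
//   /job:worker/replica:0/task:7/device:GPU:3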
static string SplitByDevice(const Node* node) {
  return node->assigned_device_name();
}

// Validates "gdef" device specifications.
static Status ValidateGraphDefForDevices(const GraphDef& gdef) {
  DeviceNameUtils::ParsedName parsed;
  for (const auto& ndef : gdef.node()) {
    if (!DeviceNameUtils::ParseFullName(ndef.device(), &parsed)) {
      return errors::InvalidArgument("Missing device name in: ",
                                     SummarizeNodeDef(ndef));
    }
  }
  return Status::OK();
}

Status GraphMgr::DecorateAndPublishGraphForDebug(
    const DebugOptions& debug_options, Graph* graph, Device* device) {
  std::unique_ptr<DebugGraphDecoratorInterface> decorator;
  TF_RETURN_IF_ERROR(
      DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
  TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
  TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
  return Status::OK();
}

// Creates executors given a graph definition "gdef" of a "session".
// If a node in "gdef" is shared by other graphs in "session", the
// same op kernel is reused. E.g., typically a params node is shared
// by multiple graphs in a session.
//
// If "gdef" is assigned to multiple devices, extra nodes (e.g.,
// send/recv nodes) may be added. The extra nodes' names are generated
// by calling "new_name(old_name)".
//
// "executors" are filled with one executor per device on success, and
// the caller takes ownership of the returned executors.
Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
                          const GraphOptions& graph_options,
                          const DebugOptions& debug_options,
                          DistributedFunctionLibraryRuntime* cluster_flr,
                          Item* item) {
  item->session = session;
  item->lib_def.reset(
      new FunctionLibraryDefinition(OpRegistry::Global(), gdef.library()));

  TF_RETURN_IF_ERROR(ValidateGraphDefForDevices(gdef));

  if (gdef.versions().producer() >= 5) {
    // Validate the graph: we assume that merging two valid graphs
    // should maintain graph validity.
    TF_RETURN_IF_ERROR(graph::ValidateGraphDef(gdef, *item->lib_def));
  }

  item->proc_flr.reset(new ProcessFunctionLibraryRuntime(
      device_mgr_, worker_env_->env, gdef.versions().producer(),
      item->lib_def.get(), graph_options.optimizer_options(), cluster_flr));

  // Constructs the graph out of "gdef".
  Graph graph(OpRegistry::Global());
  GraphConstructorOptions opts;
  opts.allow_internal_ops = true;
  opts.expect_device_spec = true;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, gdef, &graph));

  // Splits "graph" into multiple subgraphs by device names.
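  // (Partition() below rewrites every edge that crosses a device boundary
  // into a paired _Send/_Recv that communicates through the rendezvous;
  // the incarnation numbers fetched via get_incarnation let a receiver
  // detect that the sending device has since restarted.)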
  std::unordered_map<string, GraphDef> partitions;
  PartitionOptions popts;
  popts.node_to_loc = SplitByDevice;
  popts.new_name = [this](const string& prefix) {
    mutex_lock l(mu_);
    return strings::StrCat(prefix, "_G", next_id_++);
  };
  popts.get_incarnation = [this](const string& name) -> int64 {
    Device* device = nullptr;
    Status s = device_mgr_->LookupDevice(name, &device);
    if (s.ok()) {
      return device->attributes().incarnation();
    } else {
      return PartitionOptions::kIllegalIncarnation;
    }
  };
  popts.flib_def = &graph.flib_def();
  popts.control_flow_added = true;
  popts.scheduling_for_recvs = graph_options.enable_recv_scheduling();
  TF_RETURN_IF_ERROR(Partition(popts, &graph, &partitions));
  if (popts.scheduling_for_recvs) {
    TF_RETURN_IF_ERROR(AddControlEdges(popts, &partitions));
  }

  std::unordered_map<string, std::unique_ptr<Graph>> partition_graphs;
  for (const auto& partition : partitions) {
    std::unique_ptr<Graph> device_graph(new Graph(OpRegistry::Global()));
    GraphConstructorOptions device_opts;
    // There are internal operations (e.g., send/recv) that we now allow.
    device_opts.allow_internal_ops = true;
    device_opts.expect_device_spec = true;
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
                                              device_graph.get()));
    partition_graphs.emplace(partition.first, std::move(device_graph));
  }

  GraphOptimizationPassOptions optimization_options;
  optimization_options.flib_def = item->lib_def.get();
  optimization_options.partition_graphs = &partition_graphs;
  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PARTITIONING, optimization_options));

  LocalExecutorParams params;

  item->units.reserve(partitions.size());
  item->graph_mgr = this;
  const auto& optimizer_opts = graph_options.optimizer_options();
  GraphOptimizer optimizer(optimizer_opts);
  for (auto& p : partition_graphs) {
    const string& device_name = p.first;
    std::unique_ptr<Graph>& subgraph = p.second;
    item->units.resize(item->units.size() + 1);
    ExecutionUnit* unit = &(item->units.back());

    // Find the device.
    Status s = device_mgr_->LookupDevice(device_name, &unit->device);
    if (!s.ok()) {
      // Remove the empty unit from the item, as the item destructor wants
      // all units to have valid devices.
      item->units.pop_back();
      return s;
    }

    // Give the device an opportunity to rewrite its subgraph.
    TF_RETURN_IF_ERROR(unit->device->MaybeRewriteGraph(&subgraph));

    // Top-level nodes in the graph use the op segment to cache
    // kernels. Therefore, as long as the executor is alive, we need
    // to ensure the kernels cached for the session are alive.
    auto opseg = unit->device->op_segment();
    opseg->AddHold(session);

    // Function library runtime.
    FunctionLibraryRuntime* lib = item->proc_flr->GetFLR(unit->device->name());
    if (lib == nullptr) {
      return errors::InvalidArgument("Cannot find FLR for device: ",
                                     unit->device->name());
    }
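    // (For context: OpSegment::FindOrCreate(), used in create_kernel below,
    // returns the kernel cached under this session when present, and
    // otherwise invokes the supplied factory and caches the result. The
    // AddHold() above and the RemoveHold() in ~Item bracket the lifetime
    // of that per-session cache.)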
    // Construct the root executor for the subgraph.
    params.device = unit->device;
    params.function_library = lib;
    params.create_kernel = [session, lib, opseg](const NodeDef& ndef,
                                                 OpKernel** kernel) {
      // We do not share the kernel via the OpSegment if the node is
      // stateless, or a function.
      // NOTE(mrry): We must not share function kernels (implemented
      // using `CallOp`) between subgraphs, because `CallOp::handle_`
      // is tied to a particular subgraph. Even if the function itself
      // is stateful, the `CallOp` that invokes it is not.
      if (!lib->IsStateful(ndef.op()) ||
          lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) {
        return lib->CreateKernel(ndef, kernel);
      }
      auto create_fn = [lib, &ndef](OpKernel** kernel) {
        return lib->CreateKernel(ndef, kernel);
      };
      // Kernels created for subgraph nodes need to be cached. On
      // cache miss, create_fn() is invoked to create a kernel based
      // on the function library here + global op registry.
      return opseg->FindOrCreate(session, ndef.name(), kernel, create_fn);
    };
    params.delete_kernel = [lib](OpKernel* kernel) {
      // If the node is stateful, opseg owns it. Otherwise, delete it.
      if (kernel && !lib->IsStateful(kernel->type_string())) {
        delete kernel;
      }
    };

    optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph,
                       /*shape_map=*/nullptr);

    // EXPERIMENTAL: tfdbg inserts debug nodes (i.e., probes) into the graph.
    if (!debug_options.debug_tensor_watch_opts().empty()) {
      TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
          debug_options, subgraph.get(), params.device));
    }

    TF_RETURN_IF_ERROR(
        EnsureMemoryTypes(DeviceType(unit->device->device_type()),
                          unit->device->name(), subgraph.get()));
    unit->graph = subgraph.get();
    unit->build_cost_model = graph_options.build_cost_model();
    if (unit->build_cost_model > 0) {
      skip_cost_models_ = false;
    }
    TF_RETURN_IF_ERROR(
        NewLocalExecutor(params, std::move(subgraph), &unit->root));
  }
  return Status::OK();
}

Status GraphMgr::Register(const string& session, const GraphDef& gdef,
                          const GraphOptions& graph_options,
                          const DebugOptions& debug_options,
                          DistributedFunctionLibraryRuntime* cluster_flr,
                          string* handle) {
  Item* item = new Item;
  Status s =
      InitItem(session, gdef, graph_options, debug_options, cluster_flr, item);
  if (!s.ok()) {
    item->Unref();
    return s;
  }

  // Inserts one item into table_.
  {
    mutex_lock l(mu_);
    *handle = strings::Printf("%016llx", ++next_id_);
    item->handle = *handle;
    CHECK(table_.insert({*handle, item}).second);
  }
  return Status::OK();
}

Status GraphMgr::Deregister(const string& handle) {
  Item* item = nullptr;
  // Removes one item from table_.
  {
    mutex_lock l(mu_);
    auto iter = table_.find(handle);
    if (iter == table_.end()) {
      return errors::Aborted("Graph handle is not found: ", handle,
                             ". Possibly, this worker just restarted.");
    }
    item = iter->second;
    table_.erase(iter);
  }
  item->Unref();
  return Status::OK();
}

Status GraphMgr::DeregisterAll() {
  std::vector<Item*> items;
  // Removes all items from table_.
  {
    mutex_lock l(mu_);
    for (const auto& entry : table_) {
      items.push_back(entry.second);
    }
    table_.clear();
  }
  for (auto item : items) {
    item->Unref();
  }
  return Status::OK();
}
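// The string keys in the NamedTensors maps used below are full rendezvous
// keys. For illustration, they follow the form produced by
// Rendezvous::CreateKey(), roughly:
//   <src device>;<hex incarnation>;<dst device>;<tensor name>;<frame>:<iter>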
Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) {
  Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  std::vector<string> keys;
  std::vector<Tensor> tensors_to_send;
  keys.reserve(in.size());
  tensors_to_send.reserve(in.size());
  for (const auto& p : in) {
    keys.push_back(p.first);
    tensors_to_send.push_back(p.second);
  }
  Status s =
      SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send);
  rendezvous->Unref();
  return s;
}

Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
  Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  Status s = RecvOutputsFromRendezvous(rendezvous, out, Rendezvous::Args());
  rendezvous->Unref();
  return s;
}

void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
                                StatusCallback done) {
  Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  std::vector<string> keys;
  std::vector<Tensor>* received_tensors = new std::vector<Tensor>;
  keys.reserve(out->size());
  received_tensors->reserve(out->size());
  for (const auto& p : *out) {
    keys.push_back(p.first);
    received_tensors->push_back(p.second);
  }
  RecvOutputsFromRendezvousAsync(
      rendezvous, nullptr, {}, keys, received_tensors,
      [done, rendezvous, received_tensors, out, keys](const Status s) {
        rendezvous->Unref();
        for (size_t i = 0; i < keys.size(); ++i) {
          (*out)[keys[i]] = (*received_tensors)[i];
        }
        delete received_tensors;
        done(s);
      });
}

void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
                            WorkerSession* session, const ExecutorOpts& opts,
                            StepStatsCollector* collector,
                            MutableRunGraphResponseWrapper* response,
                            CancellationManager* cancellation_manager,
                            const NamedTensors& in, StatusCallback done) {
  // Looks up an item. Holds one ref while executing.
  Item* item = nullptr;
  {
    mutex_lock l(mu_);
    auto iter = table_.find(handle);
    if (iter != table_.end()) {
      item = iter->second;
      item->Ref();
    }
  }

  if (item == nullptr) {
    done(errors::Aborted("Graph handle is not found: ", handle));
    return;
  }

  CostGraphDef* cost_graph = nullptr;
  if (response != nullptr) {
    cost_graph = response->mutable_cost_graph();
    if (opts.record_partition_graphs()) {
      for (const ExecutionUnit& unit : item->units) {
        GraphDef graph_def;
        unit.graph->ToGraphDef(&graph_def);
        response->AddPartitionGraph(graph_def);
      }
    }
  }

  RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  Status s = rendezvous->Initialize(session);

  // Sends values specified by the caller.
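  // (These feed tensors travel through the same per-step rendezvous that
  // the executors use, so the _Recv ops inserted at partition time can
  // pick them up by matching key.)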
  if (s.ok()) {
    std::vector<string> keys;
    std::vector<Tensor> tensors_to_send;
    keys.reserve(in.size());
    tensors_to_send.reserve(in.size());
    for (auto& p : in) {
      keys.push_back(p.first);
      tensors_to_send.push_back(p.second);
    }
    s = SendTensorsToRendezvous(rendezvous, nullptr, {}, keys,
                                tensors_to_send);
  }

  if (!s.ok()) {
    done(s);
    item->Unref();
    rendezvous->Unref();
    return;
  }

  StartParallelExecutors(handle, step_id, item, rendezvous, collector,
                         cost_graph, cancellation_manager,
                         [this, item, rendezvous, done](const Status& s) {
                           done(s);
                           rendezvous->Unref();
                           item->Unref();
                         });
}

void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
                                      Item* item, Rendezvous* rendezvous,
                                      StepStatsCollector* collector,
                                      CostGraphDef* cost_graph,
                                      CancellationManager* cancellation_manager,
                                      StatusCallback done) {
  const int num_units = item->units.size();
  CHECK_GE(num_units, 1);
  ScopedStepContainer* step_container = new ScopedStepContainer(
      step_id,
      [this](const string& name) { device_mgr_->ClearContainers({name}); });
  // NOTE: Transfer one ref of rendezvous and item.
  ExecutorBarrier* barrier =
      new ExecutorBarrier(num_units, rendezvous,
                          [this, item, collector, cost_graph, step_container,
                           done](const Status& s) {
                            BuildCostModel(item, collector, cost_graph);
                            done(s);
                            delete step_container;
                          });
  Executor::Args args;
  {
    mutex_lock l(mu_);
    args.step_id = ++next_id_;
  }
  args.rendezvous = rendezvous;
  args.cancellation_manager = cancellation_manager;
  args.stats_collector = collector;
  args.step_container = step_container;
  args.sync_on_finish = sync_on_finish_;
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(args.step_id, handle);
  }
  thread::ThreadPool* pool = worker_env_->compute_pool;
  using std::placeholders::_1;
  // The line below is equivalent to this code, but does one fewer indirect
  // call:
  //   args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
  auto default_runner = std::bind(&thread::ThreadPool::Schedule, pool, _1);
  for (const auto& unit : item->units) {
    // TODO(zhengxq): if the device picks its own threadpool, we need to
    // assign fewer threads to the main compute pool by default.
    thread::ThreadPool* device_thread_pool =
        unit.device->tensorflow_device_thread_pool();
    if (!device_thread_pool) {
      args.runner = default_runner;
    } else {
      args.runner =
          std::bind(&thread::ThreadPool::Schedule, device_thread_pool, _1);
    }
    unit.root->RunAsync(args, barrier->Get());
  }
}

void GraphMgr::BuildCostModel(Item* item, StepStatsCollector* collector,
                              CostGraphDef* cost_graph) {
  if (collector && !skip_cost_models_) {
    // Build the cost model.
    std::unordered_map<string, const Graph*> device_to_graph;
    for (const auto& unit : item->units) {
      if (unit.build_cost_model > 0) {
        device_to_graph[unit.device->name()] = unit.graph;
      }
    }
    collector->BuildCostModel(&cost_model_manager_, device_to_graph);

    if (cost_graph != nullptr) {
      for (const auto& unit : item->units) {
        cost_model_manager_.AddToCostGraphDef(unit.graph, cost_graph)
            .IgnoreError();
      }
    }
  }
}

}  // end namespace tensorflow
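// A minimal sketch of the intended caller-side sequence (hypothetical
// variable names; error handling elided):
//
//   GraphMgr mgr(worker_env, device_mgr);
//   string handle;
//   TF_CHECK_OK(mgr.Register("session_0", gdef, graph_options,
//                            debug_options, cluster_flr, &handle));
//   mgr.ExecuteAsync(handle, step_id, worker_session, exec_opts, collector,
//                    response, cancellation_mgr, inputs,
//                    [](const Status& s) { TF_CHECK_OK(s); });
//   NamedTensors outputs;
//   TF_CHECK_OK(mgr.RecvOutputs(step_id, &outputs));
//   TF_CHECK_OK(mgr.Deregister(handle));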