/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/master_session.h"

#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/profile_handler.h"
#include "tensorflow/core/common_runtime/stats_publisher_interface.h"
#include "tensorflow/core/debug/debug_graph_utils.h"
#include "tensorflow/core/distributed_runtime/scheduler.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

// MasterSession wraps ClientGraph in a reference counted object.
// This way, MasterSession can clear up the cache mapping Run requests to
// compiled graphs while the compiled graph is still being used.
//
// TODO(zhifengc): Cleanup this class. It's becoming messy.
class MasterSession::ReffedClientGraph : public core::RefCounted {
 public:
  ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
                    std::unique_ptr<ClientGraph> cg,
                    const SessionOptions& session_opts,
                    const StatsPublisherFactory& stats_publisher_factory,
                    GraphExecutionState* execution_state, bool is_partial,
                    WorkerCacheInterface* worker_cache, bool should_deregister)
      : session_handle_(handle),
        client_graph_(std::move(cg)),
        session_opts_(session_opts),
        is_partial_(is_partial),
        debug_opts_(bopts.debug_options),
        worker_cache_(worker_cache),
        should_deregister_(should_deregister) {
    VLOG(1) << "Created ReffedClientGraph for node with "
            << client_graph()->graph.num_node_ids();

    stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);

    // Initialize a name to node map for testing that fetches are reachable.
    for (Node* n : execution_state->full_graph()->nodes()) {
      name_to_node_.insert({n->name(), n});
    }
  }

  ~ReffedClientGraph() override {
    if (should_deregister_) {
      DeregisterPartitions();
    }
  }

  const ClientGraph* client_graph() { return client_graph_.get(); }

  std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step,
                                                    int64 execution_count,
                                                    const RunOptions& ropts) {
    return stats_publisher_->GetProfileHandler(step, execution_count, ropts);
  }

  // Turn RPC logging on or off, both at the WorkerCache used by this
  // master process, and at each remote worker in use for the current
  // partitions.
  void SetRPCLogging(bool active) {
    worker_cache_->SetLogging(active);
    // Logging is a best-effort activity, so we make async calls to turn
    // it on/off and don't make use of the responses.
    for (auto& p : partitions_) {
      LoggingRequest* req = new LoggingRequest;
      req->set_rpc_logging(active);
      LoggingResponse* resp = new LoggingResponse;
      Ref();
      p.worker->LoggingAsync(req, resp, [this, req, resp](const Status& s) {
        delete req;
        delete resp;
        // ReffedClientGraph owns p.worker so we need to hold a ref to
        // ensure that the method doesn't attempt to access p.worker after
        // ReffedClientGraph has deleted it.
        // TODO(suharshs): Simplify this ownership model.
        Unref();
      });
    }
  }

  // Retrieve all RPC logs data accumulated for the current step, both
  // from the local WorkerCache in use by this master process and from
  // all the remote workers executing the remote partitions.
  void RetrieveLogs(int64 step_id, StepStats* ss) {
    // Get the local data first, because it sets *ss without merging.
    worker_cache_->RetrieveLogs(step_id, ss);

    // Then merge in data from all the remote workers.
    LoggingRequest req;
    req.add_fetch_step_id(step_id);
    int waiting_for = partitions_.size();
    if (waiting_for > 0) {
      mutex scoped_mu;
      BlockingCounter all_done(waiting_for);
      for (auto& p : partitions_) {
        LoggingResponse* resp = new LoggingResponse;
        p.worker->LoggingAsync(
            &req, resp,
            [step_id, ss, resp, &scoped_mu, &waiting_for,
             &all_done](const Status& s) {
              {
                mutex_lock l(scoped_mu);
                if (s.ok()) {
                  for (auto& lss : resp->step()) {
                    if (step_id != lss.step_id()) {
                      LOG(ERROR) << "Wrong step_id in LoggingResponse";
                      continue;
                    }
                    ss->MergeFrom(lss.step_stats());
                  }
                }
                delete resp;
              }
              // Must not decrement all_done until out of the critical section
              // where *ss is updated.
              all_done.DecrementCount();
            });
      }
      all_done.Wait();
    }
  }

  // Local execution methods.

  // Partitions the graph into subgraphs and registers them on
  // workers.
  Status RegisterPartitions(const PartitionOptions& popts);

  // Runs one step of all partitions.
  Status RunPartitions(const MasterEnv* env, int64 step_id,
                       int64 execution_count, PerStepState* pss,
                       CallOptions* opts, const RunStepRequestWrapper& req,
                       MutableRunStepResponseWrapper* resp,
                       CancellationManager* cm, const bool is_last_partial_run);

  // Calls workers to clean up states for the step "step_id". Calls
  // `done` when all cleanup RPCs have completed.
  void CleanupPartitionsAsync(int64 step_id, StatusCallback done);

  // Post-processing of any runtime statistics gathered during execution.
  void ProcessStats(int64 step_id, PerStepState* pss, ProfileHandler* ph,
                    const RunOptions& options, RunMetadata* resp);
  void ProcessDeviceStats(ProfileHandler* ph, const DeviceStepStats& ds,
                          bool is_rpc);
  // Checks that the requested fetches can be computed from the provided feeds.
  Status CheckFetches(const RunStepRequestWrapper& req,
                      const RunState* run_state,
                      GraphExecutionState* execution_state);

  string DetailText(const Node& node, const NodeExecStats& ns) {
    int64 tot = 0;
    for (auto& no : ns.output()) {
      tot += no.tensor_description().allocation_description().requested_bytes();
    }
    string bytes;
    if (tot >= 0.1 * 1048576.0) {
      bytes = strings::Printf("[%.1fMB] ", tot / 1048576.0);
    }
    return strings::StrCat(bytes, node.name(), " = ", node.type_string(), "(",
                           str_util::Join(node.requested_inputs(), ", "), ")");
  }

 private:
  const string session_handle_;
  const std::unique_ptr<ClientGraph> client_graph_;
  const SessionOptions session_opts_;
  const bool is_partial_;
  const DebugOptions& debug_opts_;
  WorkerCacheInterface* const worker_cache_;  // Not owned.
  std::unordered_map<StringPiece, Node*, StringPieceHasher> name_to_node_;
  const bool should_deregister_;

  // Graph partitioned into per-location subgraphs.
  struct Part {
    // Worker name.
    string name;

    // Maps feed names to rendezvous keys. Empty most of the time.
    std::unordered_map<string, string> feed_key;

    // Maps rendezvous keys to fetch names. Empty most of the time.
    std::unordered_map<string, string> key_fetch;

    // The interface to the worker. Owned.
    WorkerInterface* worker = nullptr;

    // After registration with the worker, graph_handle identifies
    // this partition on the worker.
    string graph_handle;

    Part() : feed_key(3), key_fetch(3) {}
  };

  // partitions_ is immutable after RegisterPartitions() call
  // finishes. RunPartitions() can access partitions_ safely without
  // acquiring locks.
  std::vector<Part> partitions_;

  mutable mutex mu_;

  // Partition initialization and registration only needs to happen
  // once. init_started_ && !init_done_ indicates the initialization
  // is ongoing.
  bool init_started_ GUARDED_BY(mu_) = false;
  Notification init_done_;

  // init_result_ remembers the initialization error if any.
  Status init_result_ GUARDED_BY(mu_);

  std::unique_ptr<StatsPublisherInterface> stats_publisher_;

  // Send/Recv nodes that are the result of client-added
  // feeds and fetches must be tracked so that the tensors
  // can be added to the local rendezvous.
  static void TrackFeedsAndFetches(Part* part, const GraphDef& graph_def,
                                   const PartitionOptions& popts);

  // The actual graph partitioning and registration implementation.
  Status DoBuildPartitions(
      PartitionOptions popts,
      std::unordered_map<string, GraphDef>* out_partitions);
  Status DoRegisterPartitions(
      const PartitionOptions& popts,
      std::unordered_map<string, GraphDef> graph_partitions);

  // Deregisters the partitions on the workers. Called in the
  // destructor and does not wait for the RPC completion.
  void DeregisterPartitions();

  TF_DISALLOW_COPY_AND_ASSIGN(ReffedClientGraph);
};

Status MasterSession::ReffedClientGraph::RegisterPartitions(
    const PartitionOptions& popts) {
  {  // Ensure register once.
    mu_.lock();
    if (!init_started_) {
      init_started_ = true;
      mu_.unlock();
      std::unordered_map<string, GraphDef> graph_defs;
      Status s = DoBuildPartitions(popts, &graph_defs);
      if (s.ok()) {
        // NOTE(mrry): The pointers in `graph_defs_for_publishing` do not
        // remain valid after the call to DoRegisterPartitions begins, so
        // `stats_publisher_` must make a copy if it wants to retain the
        // GraphDef objects.
        std::vector<const GraphDef*> graph_defs_for_publishing;
        graph_defs_for_publishing.reserve(partitions_.size());
        for (const auto& name_def : graph_defs) {
          graph_defs_for_publishing.push_back(&name_def.second);
        }
        stats_publisher_->PublishGraphProto(graph_defs_for_publishing);
        s = DoRegisterPartitions(popts, std::move(graph_defs));
      }
      mu_.lock();
      init_result_ = s;
      init_done_.Notify();
    } else {
      mu_.unlock();
      init_done_.WaitForNotification();
      mu_.lock();
    }
    const Status result = init_result_;
    mu_.unlock();
    return result;
  }
}

static string SplitByWorker(const Node* node) {
  string task;
  string device;
  CHECK(DeviceNameUtils::SplitDeviceName(node->assigned_device_name(), &task,
                                         &device))
      << "node: " << node->name() << " dev: " << node->assigned_device_name();
  return task;
}

void MasterSession::ReffedClientGraph::TrackFeedsAndFetches(
    Part* part, const GraphDef& graph_def, const PartitionOptions& popts) {
  for (int i = 0; i < graph_def.node_size(); ++i) {
    const NodeDef& ndef = graph_def.node(i);
    const bool is_recv = ndef.op() == "_Recv";
    const bool is_send = ndef.op() == "_Send";

    if (is_recv || is_send) {
      // Only send/recv nodes that were added as feeds and fetches
      // (client-terminated) should be tracked. Other send/recv nodes
      // are for transferring data between partitions / memory spaces.
      bool client_terminated;
      TF_CHECK_OK(GetNodeAttr(ndef, "client_terminated", &client_terminated));
      if (client_terminated) {
        string name;
        TF_CHECK_OK(GetNodeAttr(ndef, "tensor_name", &name));
        string send_device;
        TF_CHECK_OK(GetNodeAttr(ndef, "send_device", &send_device));
        string recv_device;
        TF_CHECK_OK(GetNodeAttr(ndef, "recv_device", &recv_device));
        uint64 send_device_incarnation;
        TF_CHECK_OK(
            GetNodeAttr(ndef, "send_device_incarnation",
                        reinterpret_cast<int64*>(&send_device_incarnation)));
        const string& key =
            Rendezvous::CreateKey(send_device, send_device_incarnation,
                                  recv_device, name, FrameAndIter(0, 0));

        if (is_recv) {
          part->feed_key.insert({name, key});
        } else {
          part->key_fetch.insert({key, name});
        }
      }
    }
  }
}

Status MasterSession::ReffedClientGraph::DoBuildPartitions(
    PartitionOptions popts,
    std::unordered_map<string, GraphDef>* out_partitions) {
  if (popts.need_to_record_start_times) {
    CostModel cost_model(true);
    cost_model.InitFromGraph(client_graph()->graph);
    // TODO(yuanbyu): Use the real cost model.
    // execution_state_->MergeFromGlobal(&cost_model);
    SlackAnalysis sa(&client_graph()->graph, &cost_model);
    sa.ComputeAsap(&popts.start_times);
  }

  // Partition the graph.
  return Partition(popts, &client_graph_->graph, out_partitions);
}

Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
    const PartitionOptions& popts,
    std::unordered_map<string, GraphDef> graph_partitions) {
  partitions_.reserve(graph_partitions.size());
  Status s;
  for (auto& name_def : graph_partitions) {
    partitions_.resize(partitions_.size() + 1);
    Part* part = &partitions_.back();
    part->name = name_def.first;
    TrackFeedsAndFetches(part, name_def.second, popts);
    part->worker = worker_cache_->CreateWorker(part->name);
    if (part->worker == nullptr) {
      s = errors::NotFound("worker ", part->name);
      break;
    }
  }
  if (!s.ok()) {
    for (Part& part : partitions_) {
      worker_cache_->ReleaseWorker(part.name, part.worker);
    }
    return s;
  }
  struct Call {
    RegisterGraphRequest req;
    RegisterGraphResponse resp;
    Status status;
  };
  const int num = partitions_.size();
  gtl::InlinedVector<Call, 4> calls(num);
  BlockingCounter done(num);
  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    Call* c = &calls[i];
    c->req.set_session_handle(session_handle_);
    c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
    *c->req.mutable_graph_options() = session_opts_.config.graph_options();
    *c->req.mutable_debug_options() = debug_opts_;
    VLOG(2) << "Register " << c->req.graph_def().DebugString();
    auto cb = [c, &done](const Status& s) {
      c->status = s;
      done.DecrementCount();
    };
    part.worker->RegisterGraphAsync(&c->req, &c->resp, cb);
  }
  done.Wait();
  for (int i = 0; i < num; ++i) {
    Call* c = &calls[i];
    s.Update(c->status);
    partitions_[i].graph_handle = c->resp.graph_handle();
  }
  return s;
}

// Helper class to manage "num" parallel RunGraph calls.
class RunManyGraphs {
 public:
  explicit RunManyGraphs(int num) : calls_(num), pending_(num) {}

  ~RunManyGraphs() {}

  // Returns the index-th call.
  struct Call {
    CallOptions opts;
    std::unique_ptr<MutableRunGraphRequestWrapper> req;
    std::unique_ptr<MutableRunGraphResponseWrapper> resp;
  };
  Call* get(int index) { return &calls_[index]; }

  // When the index-th call is done, updates the overall status.
  void WhenDone(int index, const Status& s) {
    TRACEPRINTF("Partition %d %s", index, s.ToString().c_str());
    auto resp = get(index)->resp.get();
    if (resp->status_code() != error::Code::OK) {
      // resp->status_code will only be non-OK if s.ok().
      mutex_lock l(mu_);
      UpdateStatusLocked(
          Status(resp->status_code(), resp->status_error_message()));
    } else if (!s.ok()) {
      mutex_lock l(mu_);
      UpdateStatusLocked(s);
    }
    pending_.DecrementCount();
  }

  void StartCancel() {
    mutex_lock l(mu_);
    UpdateStatusLocked(errors::Cancelled("RunManyGraphs"));
  }

  void Wait() { pending_.Wait(); }

  Status status() const {
    mutex_lock l(mu_);
    return status_;
  }

 private:
  gtl::InlinedVector<Call, 4> calls_;

  BlockingCounter pending_;
  mutable mutex mu_;
  Status status_ GUARDED_BY(mu_);

  void UpdateStatusLocked(const Status& s) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (status_.ok()) {
      status_ = s;
      for (Call& call : calls_) {
        call.opts.StartCancel();
      }
    }
  }

  TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
};

Status MasterSession::ReffedClientGraph::RunPartitions(
    const MasterEnv* env, int64 step_id, int64 execution_count,
    PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req,
    MutableRunStepResponseWrapper* resp, CancellationManager* cm,
    const bool is_last_partial_run) {
  VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
          << execution_count;
  // Maps the names of fed tensors to their index in `req`.
  std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);

  for (size_t i = 0; i < req.num_feeds(); ++i) {
    if (!feeds.insert({req.feed_name(i), i}).second) {
      return errors::InvalidArgument("Duplicated feeds: ", req.feed_name(i));
    }
  }

  // Prepares a number of calls to workers. One call per partition.

  // Collect execution cost stats on a smoothly decreasing frequency.
  ExecutorOpts exec_opts;
  if (pss->report_tensor_allocations_upon_oom) {
    exec_opts.set_report_tensor_allocations_upon_oom(true);
  }
  if (pss->collect_costs) {
    exec_opts.set_record_costs(true);
  }
  if (pss->collect_timeline) {
    exec_opts.set_record_timeline(true);
  }
  if (pss->collect_rpcs) {
    SetRPCLogging(true);
  }
  if (pss->collect_partition_graphs) {
    exec_opts.set_record_partition_graphs(true);
  }
  if (pss->collect_costs || pss->collect_timeline) {
    pss->step_stats.resize(partitions_.size());
  }

  const int num = partitions_.size();
  RunManyGraphs calls(num);

  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    RunManyGraphs::Call* c = calls.get(i);
    c->req.reset(part.worker->CreateRunGraphRequest());
    c->resp.reset(part.worker->CreateRunGraphResponse());
    if (is_partial_) {
      c->req->set_is_partial(is_partial_);
      c->req->set_is_last_partial_run(is_last_partial_run);
    }
    c->req->set_session_handle(session_handle_);
    c->req->set_graph_handle(part.graph_handle);
    c->req->set_step_id(step_id);
    *c->req->mutable_exec_opts() = exec_opts;
    c->req->set_store_errors_in_response_body(true);
    // If any feeds are provided, send the feed values together
    // in the RunGraph request.
    // In the partial case, we only want to include feeds provided in the req.
    // In the non-partial case, all feeds in the request are in the part.
    // We keep these as separate paths for now, to ensure we aren't
    // inadvertently slowing down the normal run path.
    if (is_partial_) {
      for (size_t i = 0; i < req.num_feeds(); ++i) {
        const string& name = req.feed_name(i);
        const auto iter = part.feed_key.find(name);
        if (iter == part.feed_key.end()) {
          // The provided feed must be for a different partition.
          continue;
        }
        const string& key = iter->second;
        auto feeds_iter = feeds.find(name);
        if (feeds_iter == feeds.end()) {
          return errors::InvalidArgument("No feed is provided for feed=", name,
                                         ", key=", key);
        } else if (feeds_iter->second != static_cast<size_t>(i)) {
          return errors::Internal("Cannot find feed named \"", name,
                                  "\" in request.");
        }
        TF_RETURN_IF_ERROR(c->req->AddSendFromRunStepRequest(req, i, key));
      }
      // TODO(suharshs): Make a map from feed to fetch_key to make this faster.
      // For now, we just iterate through partitions to find the matching key.
      for (int i = 0; static_cast<size_t>(i) < req.num_fetches(); ++i) {
        const string& req_fetch = req.fetch_name(i);
        for (const auto& key_fetch : part.key_fetch) {
          if (key_fetch.second == req_fetch) {
            c->req->add_recv_key(key_fetch.first);
            break;
          }
        }
      }
    } else {
      for (const auto& feed_key : part.feed_key) {
        const string& feed = feed_key.first;
        const string& key = feed_key.second;
        const int64 feed_index = feeds[feed];
        TF_RETURN_IF_ERROR(
            c->req->AddSendFromRunStepRequest(req, feed_index, key));
      }
      for (const auto& key_fetch : part.key_fetch) {
        const string& key = key_fetch.first;
        c->req->add_recv_key(key);
      }
    }
  }

  // Issues RunGraph calls.
  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    RunManyGraphs::Call* call = calls.get(i);
    TRACEPRINTF("Partition %d %s", i, part.name.c_str());
    part.worker->RunGraphAsync(
        &call->opts, call->req.get(), call->resp.get(),
        std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1));
  }

  // Waits for the RunGraph calls.
  call_opts->SetCancelCallback([&calls]() { calls.StartCancel(); });
  auto token = cm->get_cancellation_token();
  const bool success =
      cm->RegisterCallback(token, [&calls]() { calls.StartCancel(); });
  if (!success) {
    calls.StartCancel();
  }
  calls.Wait();
  call_opts->ClearCancelCallback();
  if (success) {
    cm->DeregisterCallback(token);
  } else {
    return errors::Cancelled("Step was cancelled");
  }

  // Collects fetches.
  Status status = calls.status();
  if (status.ok()) {
    for (int i = 0; i < num; ++i) {
      const Part& part = partitions_[i];
      MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get();
      for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) {
        auto iter = part.key_fetch.find(run_graph_resp->recv_key(j));
        if (iter == part.key_fetch.end()) {
          status.Update(errors::Internal("Unexpected fetch key: ",
                                         run_graph_resp->recv_key(j)));
          break;
        }
        const string& fetch = iter->second;
        status.Update(
            resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j));
        if (!status.ok()) {
          break;
        }
      }
      if (pss->collect_timeline) {
        pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats());
      }
      if (pss->collect_costs) {
        CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph();
        for (int j = 0; j < cost_graph->node_size(); ++j) {
          resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
              cost_graph->mutable_node(j));
        }
      }
      if (pss->collect_partition_graphs) {
        protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
            resp->mutable_metadata()->mutable_partition_graphs();
        for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) {
          partition_graph_defs->Add()->Swap(
              run_graph_resp->mutable_partition_graph(i));
        }
      }
    }
  }
  return status;
}

namespace {

class CleanupBroadcastHelper {
 public:
  CleanupBroadcastHelper(int64 step_id, int num_calls, StatusCallback done)
      : resps_(num_calls), num_pending_(num_calls), done_(std::move(done)) {
    req_.set_step_id(step_id);
  }

  // Returns a non-owned pointer to a request buffer for all calls.
  CleanupGraphRequest* request() { return &req_; }

  // Returns a non-owned pointer to a response buffer for the ith call.
  CleanupGraphResponse* response(int i) { return &resps_[i]; }

  // Called when the ith response is received.
  void call_done(int i, const Status& s) {
    bool run_callback = false;
    Status status_copy;
    {
      mutex_lock l(mu_);
      status_.Update(s);
      if (--num_pending_ == 0) {
        run_callback = true;
        status_copy = status_;
      }
    }
    if (run_callback) {
      done_(status_copy);
      // This is the last call, so delete the helper object.
      delete this;
    }
  }

 private:
  // A single request shared between all workers.
  CleanupGraphRequest req_;
  // One response buffer for each worker.
  gtl::InlinedVector<CleanupGraphResponse, 4> resps_;

  mutex mu_;
  // Number of requests remaining to be collected.
  int num_pending_ GUARDED_BY(mu_);
  // Aggregate status of the operation.
  Status status_ GUARDED_BY(mu_);
  // Callback to be called when all operations complete.
  StatusCallback done_;

  TF_DISALLOW_COPY_AND_ASSIGN(CleanupBroadcastHelper);
};

}  // namespace

void MasterSession::ReffedClientGraph::CleanupPartitionsAsync(
    int64 step_id, StatusCallback done) {
  const int num = partitions_.size();
  // Helper object will be deleted when the final call completes.
  CleanupBroadcastHelper* helper =
      new CleanupBroadcastHelper(step_id, num, std::move(done));
  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    part.worker->CleanupGraphAsync(
        helper->request(), helper->response(i),
        [helper, i](const Status& s) { helper->call_done(i, s); });
  }
}

void MasterSession::ReffedClientGraph::ProcessStats(int64 step_id,
                                                    PerStepState* pss,
                                                    ProfileHandler* ph,
                                                    const RunOptions& options,
                                                    RunMetadata* resp) {
  if (!pss->collect_costs && !pss->collect_timeline) return;

  // Out-of-band logging data is collected now, during post-processing.
  if (pss->collect_timeline) {
    SetRPCLogging(false);
    RetrieveLogs(step_id, &pss->rpc_stats);
  }
  for (size_t i = 0; i < partitions_.size(); ++i) {
    const StepStats& ss = pss->step_stats[i];
    if (ph) {
      for (const auto& ds : ss.dev_stats()) {
        ProcessDeviceStats(ph, ds, false /*is_rpc*/);
      }
    }
  }
  if (ph) {
    for (const auto& ds : pss->rpc_stats.dev_stats()) {
      ProcessDeviceStats(ph, ds, true /*is_rpc*/);
    }
    ph->StepDone(pss->start_micros, pss->end_micros,
                 Microseconds(0) /*cleanup_time*/, 0 /*total_runops*/,
                 Status::OK());
  }
  // Assemble all stats for this timeline into a merged StepStats.
  if (pss->collect_timeline) {
    StepStats step_stats_proto;
    step_stats_proto.Swap(&pss->rpc_stats);
    for (size_t i = 0; i < partitions_.size(); ++i) {
      step_stats_proto.MergeFrom(pss->step_stats[i]);
      pss->step_stats[i].Clear();
    }
    pss->step_stats.clear();
    // Copy the stats back, but only for on-demand profiling to avoid slowing
    // down calls that trigger the automatic profiling.
    if (options.trace_level() == RunOptions::FULL_TRACE) {
      resp->mutable_step_stats()->Swap(&step_stats_proto);
    } else {
      // In the FULL_TRACE case the stats are returned via the Session API,
      // so publishing them here would be duplicated work.
      stats_publisher_->PublishStatsProto(step_stats_proto);
    }
  }
}

void MasterSession::ReffedClientGraph::ProcessDeviceStats(
    ProfileHandler* ph, const DeviceStepStats& ds, bool is_rpc) {
  const string& dev_name = ds.device();
  VLOG(1) << "Device " << dev_name << " reports stats for "
          << ds.node_stats_size() << " nodes";
  for (const auto& ns : ds.node_stats()) {
    if (is_rpc) {
      // We don't have access to a good Node pointer, so we rely on
      // sufficient data being present in the NodeExecStats.
      ph->RecordOneOp(dev_name, ns, true /*is_copy*/, "", ns.node_name(),
                      ns.timeline_label());
    } else {
      const Node* node = name_to_node_[ns.node_name()];
      const bool found_node_in_graph = node != nullptr;
      if (!found_node_in_graph && ns.timeline_label().empty()) {
        // The counter incrementing is not thread-safe. But we don't really
        // care.
        // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N for
        // more general usage.
        static int log_counter = 0;
        if (log_counter < 10) {
          log_counter++;
          LOG(WARNING) << "Failed to find node " << ns.node_name()
                       << " for dev " << dev_name;
        }
        continue;
      }
      string optype =
          found_node_in_graph ? node->type_string() : ns.node_name();
      string details;
      if (!ns.timeline_label().empty()) {
        details = ns.timeline_label();
      } else if (found_node_in_graph) {
        details = DetailText(*node, ns);
      } else {
        // Leave details string empty.
      }
      ph->RecordOneOp(dev_name, ns, false /*is_copy*/, ns.node_name(), optype,
                      details);
    }
  }
}

// TODO(suharshs): Merge with CheckFetches in DirectSession.
// TODO(suharsh,mrry): Build a map from fetch target to set of feeds it depends
// on once at setup time to prevent us from computing the dependencies
// every time.
// TODO(suharshs,mrry): Consider removing the need for execution_state to
// reduce contention.
Status MasterSession::ReffedClientGraph::CheckFetches(
    const RunStepRequestWrapper& req, const RunState* run_state,
    GraphExecutionState* execution_state) {
  // Build the set of pending feeds that we haven't seen.
  std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
  for (const auto& input : run_state->pending_inputs) {
    // Skip if already fed.
    if (input.second) continue;
    TensorId id(ParseTensorName(input.first));
    const auto it = name_to_node_.find(id.first);
    if (it == name_to_node_.end()) {
      return errors::NotFound("Feed ", input.first, ": not found");
    }
    pending_feeds.insert(id);
  }
  for (size_t i = 0; i < req.num_feeds(); ++i) {
    const TensorId id(ParseTensorName(req.feed_name(i)));
    pending_feeds.erase(id);
  }

  // Initialize the stack with the fetch nodes.
  std::vector<const Node*> stack;
  for (size_t i = 0; i < req.num_fetches(); ++i) {
    const string& fetch = req.fetch_name(i);
    const TensorId id(ParseTensorName(fetch));
    auto it = name_to_node_.find(id.first);
    if (it == name_to_node_.end()) {
      return errors::NotFound("Fetch ", fetch, ": not found");
    }
    stack.push_back(it->second);
  }

  // Any tensor needed for fetches can't be in pending_feeds.
  // We need to use the original full graph from execution state.
  const Graph* graph = execution_state->full_graph();
  std::vector<bool> visited(graph->num_node_ids(), false);
  while (!stack.empty()) {
    const Node* n = stack.back();
    stack.pop_back();

    for (const Edge* in_edge : n->in_edges()) {
      const Node* in_node = in_edge->src();
      if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
        return errors::InvalidArgument("Fetch ", in_node->name(), ":",
                                       in_edge->src_output(),
                                       " can't be computed from the feeds"
                                       " that have been fed so far.");
      }
      if (!visited[in_node->id()]) {
        visited[in_node->id()] = true;
        stack.push_back(in_node);
      }
    }
  }
  return Status::OK();
}

// Asynchronously deregisters subgraphs on the workers, without waiting for the
// result.
void MasterSession::ReffedClientGraph::DeregisterPartitions() {
  struct Call {
    DeregisterGraphRequest req;
    DeregisterGraphResponse resp;
  };
  for (Part& part : partitions_) {
    // The graph handle may be empty if we failed during partition
    // registration.
    if (!part.graph_handle.empty()) {
      Call* c = new Call;
      c->req.set_session_handle(session_handle_);
      c->req.set_graph_handle(part.graph_handle);
      // NOTE(mrry): We must capture `worker_cache_` since `this`
      // could be deleted before the callback is called.
      WorkerCacheInterface* worker_cache = worker_cache_;
      const string name = part.name;
      WorkerInterface* w = part.worker;
      CHECK_NOTNULL(w);
      auto cb = [worker_cache, c, name, w](const Status& s) {
        if (!s.ok()) {
          // This error is potentially benign, so we don't log at the
          // error level.
          LOG(INFO) << "DeregisterGraph error: " << s;
        }
        delete c;
        worker_cache->ReleaseWorker(name, w);
      };
      w->DeregisterGraphAsync(&c->req, &c->resp, cb);
    }
  }
}

void BuildBuildGraphOptions(const RunStepRequestWrapper& req,
                            BuildGraphOptions* opts) {
  for (size_t i = 0; i < req.num_feeds(); ++i) {
    opts->feed_endpoints.push_back(req.feed_name(i));
  }
  for (size_t i = 0; i < req.num_fetches(); ++i) {
    opts->fetch_endpoints.push_back(req.fetch_name(i));
  }
  for (size_t i = 0; i < req.num_targets(); ++i) {
    opts->target_nodes.push_back(req.target_name(i));
  }

  if (!req.options().debug_options().debug_tensor_watch_opts().empty()) {
    opts->debug_options = req.options().debug_options();
  }

  std::sort(opts->feed_endpoints.begin(), opts->feed_endpoints.end());
  std::sort(opts->target_nodes.begin(), opts->target_nodes.end());
  std::sort(opts->fetch_endpoints.begin(), opts->fetch_endpoints.end());
}

void BuildBuildGraphOptions(const PartialRunSetupRequest& req,
                            BuildGraphOptions* opts) {
  for (const auto& feed : req.feed()) {
    opts->feed_endpoints.push_back(feed);
  }
  for (const auto& fetch : req.fetch()) {
    opts->fetch_endpoints.push_back(fetch);
  }
  for (const auto& target : req.target()) {
    opts->target_nodes.push_back(target);
  }

  // TODO(cais): Add TFDBG support to partial runs.

  std::sort(opts->feed_endpoints.begin(), opts->feed_endpoints.end());
  std::sort(opts->target_nodes.begin(), opts->target_nodes.end());
  std::sort(opts->fetch_endpoints.begin(), opts->fetch_endpoints.end());
}

uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
  uint64 h = 0x2b992ddfa23249d6ull;
  for (const string& name : opts.feed_endpoints) {
    h = Hash64(name.c_str(), name.size(), h);
  }
  for (const string& name : opts.target_nodes) {
    h = Hash64(name.c_str(), name.size(), h);
  }
  for (const string& name : opts.fetch_endpoints) {
    h = Hash64(name.c_str(), name.size(), h);
  }

  if (!opts.debug_options.debug_tensor_watch_opts().empty()) {
    const string watch_summary = SummarizeDebugTensorWatches(
        opts.debug_options.debug_tensor_watch_opts());
    h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
  }

  return h;
}

string BuildGraphOptionsString(const BuildGraphOptions& opts) {
  string buf;
  for (const string& name : opts.feed_endpoints) {
    strings::StrAppend(&buf, " FdE: ", name);
  }
  strings::StrAppend(&buf, "\n");
  for (const string& name : opts.target_nodes) {
    strings::StrAppend(&buf, " TN: ", name);
  }
  strings::StrAppend(&buf, "\n");
  for (const string& name : opts.fetch_endpoints) {
    strings::StrAppend(&buf, " FeE: ", name);
  }
  strings::StrAppend(&buf, "\n");
  return buf;
}

MasterSession::MasterSession(
    const SessionOptions& opt, const MasterEnv* env,
    std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
    std::unique_ptr<WorkerCacheInterface> worker_cache,
    std::unique_ptr<DeviceSet> device_set,
    StatsPublisherFactory stats_publisher_factory)
    : session_opts_(opt),
      env_(env),
      handle_(strings::FpToString(random::New64())),
      remote_devs_(std::move(remote_devs)),
      worker_cache_(std::move(worker_cache)),
      devices_(std::move(device_set)),
      stats_publisher_factory_(std::move(stats_publisher_factory)),
      graph_version_(0),
      run_graphs_(5),
      partial_run_graphs_(5) {
  UpdateLastAccessTime();
  CHECK(devices_) << "device_set was null!";

  VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size()
          << " #remote " << remote_devs_->size();

  LOG(INFO) << "Start master session " << handle_
            << " with config: " << session_opts_.config.ShortDebugString();
}

MasterSession::~MasterSession() {
  for (const auto& iter : run_graphs_) iter.second->Unref();
  for (const auto& iter : partial_run_graphs_) iter.second->Unref();
}

void MasterSession::UpdateLastAccessTime() {
  last_access_time_usec_.store(Env::Default()->NowMicros());
}

Status MasterSession::Create(GraphDef* graph_def,
                             const WorkerCacheFactoryOptions& options) {
  if (session_opts_.config.use_per_session_threads() ||
      session_opts_.config.session_inter_op_thread_pool_size() > 0) {
    return errors::InvalidArgument(
        "Distributed session does not support session thread pool options.");
  }
  if (session_opts_.config.graph_options().place_pruned_graph()) {
    // TODO(b/29900832): Fix this or remove the option.
    LOG(WARNING) << "Distributed session does not support the "
                    "place_pruned_graph option.";
    session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false);
  }

  GraphExecutionStateOptions execution_options;
  execution_options.device_set = devices_.get();
  execution_options.session_options = &session_opts_;
  {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
        graph_def, execution_options, &execution_state_));
  }
  // TODO(b/36574172): Remove these conditions when ClusterSpec
  // propagation is supported in all servers.
  if (options.cluster_def != nullptr ||
      session_opts_.config.isolate_session_state()) {
    should_delete_worker_sessions_ = true;
    return CreateWorkerSessions(options);
  }
  return Status::OK();
}

Status MasterSession::CreateWorkerSessions(
    const WorkerCacheFactoryOptions& options) {
  std::vector<string> worker_names;
  WorkerCacheInterface* worker_cache = get_worker_cache();
  worker_cache->ListWorkers(&worker_names);

  struct WorkerGroup {
    // The worker name. (Not owned.)
    const string* name;

    // The worker referenced by name. (Not owned.)
    WorkerInterface* worker = nullptr;

    // Request and responses used for a given worker.
    CreateWorkerSessionRequest request;
    CreateWorkerSessionResponse response;
    Status status = Status::OK();
  };
  BlockingCounter done(worker_names.size());
  std::vector<WorkerGroup> workers(worker_names.size());

  // Release the workers.
  auto cleanup = gtl::MakeCleanup([this, &workers, worker_cache] {
    for (auto&& worker_group : workers) {
      if (worker_group.worker != nullptr) {
        worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
      }
    }
  });

  Status status = Status::OK();
  // Create all the workers & kick off the computations.
  for (size_t i = 0; i < worker_names.size(); ++i) {
    workers[i].name = &worker_names[i];
    workers[i].worker = worker_cache->CreateWorker(worker_names[i]);
    workers[i].request.set_session_handle(handle_);
    if (options.cluster_def) {
      *workers[i].request.mutable_server_def()->mutable_cluster() =
          *options.cluster_def;
      workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
      // Session state is always isolated when ClusterSpec propagation
      // is in use.
      workers[i].request.set_isolate_session_state(true);
    } else {
      workers[i].request.set_isolate_session_state(
          session_opts_.config.isolate_session_state());
    }

    DeviceNameUtils::ParsedName name;
    if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) {
      status = errors::Internal("Could not parse name ", worker_names[i]);
      LOG(WARNING) << status;
      return status;
    }
    if (!name.has_job || !name.has_task) {
      status = errors::Internal("Incomplete worker name ", worker_names[i]);
      LOG(WARNING) << status;
      return status;
    }

    workers[i].request.mutable_server_def()->set_job_name(name.job);
    workers[i].request.mutable_server_def()->set_task_index(name.task);
  }

  for (size_t i = 0; i < worker_names.size(); ++i) {
    auto cb = [i, &workers, &done](const Status& s) {
      workers[i].status = s;
      done.DecrementCount();
    };
    workers[i].worker->CreateWorkerSessionAsync(&workers[i].request,
                                                &workers[i].response, cb);
  }

  done.Wait();
  for (size_t i = 0; i < workers.size(); ++i) {
    status.Update(workers[i].status);
  }
  return status;
}

Status MasterSession::DeleteWorkerSessions() {
  WorkerCacheInterface* worker_cache = get_worker_cache();
  std::vector<string> worker_names;
  worker_cache->ListWorkers(&worker_names);

  struct WorkerGroup {
    // The worker name. (Not owned.)
    const string* name;

    // The worker referenced by name. (Not owned.)
    WorkerInterface* worker = nullptr;

    // Request and responses used for a given worker.
    DeleteWorkerSessionRequest request;
    DeleteWorkerSessionResponse response;
    Status status = Status::OK();
  };
  BlockingCounter done(worker_names.size());
  std::vector<WorkerGroup> workers(worker_names.size());

  // Release the workers.
  auto cleanup = gtl::MakeCleanup([this, &workers, worker_cache] {
    for (auto&& worker_group : workers) {
      if (worker_group.worker != nullptr) {
        worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
      }
    }
  });

  Status status = Status::OK();
  // Create all the workers & kick off the session deletions.
  for (size_t i = 0; i < worker_names.size(); ++i) {
    workers[i].name = &worker_names[i];
    workers[i].worker = worker_cache->CreateWorker(worker_names[i]);
    workers[i].request.set_session_handle(handle_);
  }

  for (size_t i = 0; i < worker_names.size(); ++i) {
    auto cb = [i, &workers, &done](const Status& s) {
      workers[i].status = s;
      done.DecrementCount();
    };
    workers[i].worker->DeleteWorkerSessionAsync(&workers[i].request,
                                                &workers[i].response, cb);
  }

  done.Wait();
  for (size_t i = 0; i < workers.size(); ++i) {
    status.Update(workers[i].status);
  }
  return status;
}

Status MasterSession::ListDevices(ListDevicesResponse* resp) const {
  if (worker_cache_) {
    // This is a ClusterSpec-propagated session, and thus env_->local_devices
    // are invalid.

    // Mark the "client_device" as the sole local device.
    const Device* client_device = devices_->client_device();
    for (const Device* dev : devices_->devices()) {
      if (dev != client_device) {
        *(resp->add_remote_device()) = dev->attributes();
      }
    }
    *(resp->add_local_device()) = client_device->attributes();
  } else {
    for (Device* dev : env_->local_devices) {
      *(resp->add_local_device()) = dev->attributes();
    }
    for (auto&& dev : *remote_devs_) {
      *(resp->add_local_device()) = dev->attributes();
    }
  }
  return Status::OK();
}

Status MasterSession::Extend(const ExtendSessionRequest* req,
                             ExtendSessionResponse* resp) {
  UpdateLastAccessTime();
  std::unique_ptr<GraphExecutionState> extended_execution_state;
  {
    mutex_lock l(mu_);
    if (closed_) {
      return errors::FailedPrecondition("Session is closed.");
    }

    if (graph_version_ != req->current_graph_version()) {
      return errors::Aborted("Current version is ", graph_version_,
                             " but caller expected ",
                             req->current_graph_version(), ".");
    }

    CHECK(execution_state_);
    TF_RETURN_IF_ERROR(
        execution_state_->Extend(req->graph_def(), &extended_execution_state));

    CHECK(extended_execution_state);
    // The old execution state will be released outside the lock.
    execution_state_.swap(extended_execution_state);
    ++graph_version_;
    resp->set_new_graph_version(graph_version_);
  }
  return Status::OK();
}

WorkerCacheInterface* MasterSession::get_worker_cache() const {
  if (worker_cache_) {
    return worker_cache_.get();
  }
  return env_->worker_cache;
}

Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
                                ReffedClientGraph** rcg, bool is_partial) {
  const uint64 hash = HashBuildGraphOptions(opts);
  {
    mutex_lock l(mu_);
    // Keep track of how many times this subgraph has been executed in
    // this session.
    int64* c = &subgraph_execution_counts_[hash];
    *count = (*c)++;
    // TODO(suharshs): We cache partial run graphs and run graphs separately
    // because there is preprocessing that needs to only be run for partial
    // run calls.
    RCGMap* m = is_partial ? &partial_run_graphs_ : &run_graphs_;
    auto iter = m->find(hash);
    if (iter == m->end()) {
      // We have not seen this subgraph before. Build the subgraph and
      // cache it.
      VLOG(1) << "Unseen hash " << hash << " for "
              << BuildGraphOptionsString(opts) << " is_partial = " << is_partial
              << "\n";
      std::unique_ptr<ClientGraph> client_graph;
      TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
      WorkerCacheInterface* worker_cache = get_worker_cache();
      auto entry = new ReffedClientGraph(
          handle_, opts, std::move(client_graph), session_opts_,
          stats_publisher_factory_, execution_state_.get(), is_partial,
          worker_cache, !should_delete_worker_sessions_);
      iter = m->insert({hash, entry}).first;
      VLOG(1) << "Preparing to execute new graph";
    }
    *rcg = iter->second;
    (*rcg)->Ref();
  }
  return Status::OK();
}

void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
                                   RCGMap* rcg_map) {
  VLOG(1) << "Discarding all reffed graphs";
  for (auto p : *rcg_map) {
    ReffedClientGraph* rcg = p.second;
    if (to_unref) {
      to_unref->push_back(rcg);
    } else {
      rcg->Unref();
    }
  }
  rcg_map->clear();
}

Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
                                      PartialRunSetupResponse* resp) {
  std::vector<string> inputs, outputs, targets;
  for (const auto& feed : req->feed()) {
    inputs.push_back(feed);
  }
  for (const auto& fetch : req->fetch()) {
    outputs.push_back(fetch);
  }
  for (const auto& target : req->target()) {
    targets.push_back(target);
  }

  string handle = std::to_string(partial_run_handle_counter_.fetch_add(1));

  ReffedClientGraph* rcg = nullptr;
  int64 count = 0;

  // Prepare.
  BuildGraphOptions opts;
  BuildBuildGraphOptions(*req, &opts);
  TF_RETURN_IF_ERROR(StartStep(opts, &count, &rcg, true));
  // Keeps the highest 8 bits 0x01: we reserve some bits of the
  // step_id for future use.
  uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
  TRACEPRINTF("stepid %llu", step_id);

  rcg->Ref();
  RunState* run_state = new RunState(inputs, outputs, rcg, step_id, count);
  {
    mutex_lock l(mu_);
    partial_runs_.emplace(
        std::make_pair(handle, std::unique_ptr<RunState>(run_state)));
  }

  TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));

  resp->set_partial_run_handle(handle);
  return Status::OK();
}

Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
                          MutableRunStepResponseWrapper* resp) {
  UpdateLastAccessTime();
  {
    mutex_lock l(mu_);
    if (closed_) {
      return errors::FailedPrecondition("Session is closed.");
    }
    ++num_running_;
    // Note: all code paths must eventually call MarkRunCompletion()
    // in order to appropriately decrement the num_running_ counter.
  }
  Status status;
  if (!req.partial_run_handle().empty()) {
    status = DoPartialRun(opts, req, resp);
  } else {
    status = DoRunWithLocalExecution(opts, req, resp);
  }
  return status;
}

// Decrements num_running_ and broadcasts if num_running_ is zero.
void MasterSession::MarkRunCompletion() {
  mutex_lock l(mu_);
  --num_running_;
  if (num_running_ == 0) {
    num_running_is_zero_.notify_all();
  }
}

Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
  // Register subgraphs if we haven't done so already.
  PartitionOptions popts;
  popts.node_to_loc = SplitByWorker;
  // The closures popts.{new_name,get_incarnation} are called synchronously in
  // RegisterPartitions() below, so they do not need a Ref()/Unref() pair to
  // keep "this" alive during the closure.
  popts.new_name = [this](const string& prefix) {
    mutex_lock l(mu_);
    return strings::StrCat(prefix, "_S", next_node_id_++);
  };
  popts.flib_def = rcg->client_graph()->flib_def.get();
  popts.get_incarnation = [this](const string& name) -> int64 {
    Device* d = devices_->FindDeviceByName(name);
    if (d == nullptr) {
      return PartitionOptions::kIllegalIncarnation;
    } else {
      return d->attributes().incarnation();
    }
  };
  popts.control_flow_added = false;
  const bool enable_bfloat16_sendrecv =
      session_opts_.config.graph_options().enable_bfloat16_sendrecv();
  popts.should_cast = [enable_bfloat16_sendrecv](const Edge* e) {
    if (e->IsControlEdge()) {
      return DT_FLOAT;
    }
    DataType dtype = BaseType(e->src()->output_type(e->src_output()));
    if (enable_bfloat16_sendrecv && dtype == DT_FLOAT) {
      return DT_BFLOAT16;
    } else {
      return dtype;
    }
  };
  if (session_opts_.config.graph_options().enable_recv_scheduling()) {
    popts.scheduling_for_recvs = true;
    popts.need_to_record_start_times = true;
  }

  TF_RETURN_IF_ERROR(rcg->RegisterPartitions(popts));

  return Status::OK();
}

Status MasterSession::DoPartialRun(CallOptions* opts,
                                   const RunStepRequestWrapper& req,
                                   MutableRunStepResponseWrapper* resp) {
  auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
  const string& prun_handle = req.partial_run_handle();
  RunState* run_state = nullptr;
  {
    mutex_lock l(mu_);
    auto it = partial_runs_.find(prun_handle);
    if (it == partial_runs_.end()) {
      return errors::InvalidArgument(
          "Must run PartialRunSetup before performing partial runs");
    }
    run_state = it->second.get();
  }

  // If this is the first partial run, initialize the PerStepState.
  if (!run_state->step_started) {
    run_state->step_started = true;
    PerStepState pss;

    const auto count = run_state->count;
    pss.collect_timeline =
        req.options().trace_level() == RunOptions::FULL_TRACE;
    pss.collect_rpcs = req.options().trace_level() == RunOptions::FULL_TRACE;
    pss.report_tensor_allocations_upon_oom =
        req.options().report_tensor_allocations_upon_oom();

    // Build the cost model every 'build_cost_model_every' steps after skipping
    // an initial 'build_cost_model_after' steps.
    const int64 build_cost_model_after =
        session_opts_.config.graph_options().build_cost_model_after();
    const int64 build_cost_model_every =
        session_opts_.config.graph_options().build_cost_model();
    pss.collect_costs =
        build_cost_model_every > 0 &&
        ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
    pss.collect_partition_graphs = req.options().output_partition_graphs();

    std::unique_ptr<ProfileHandler> ph = run_state->rcg->GetProfileHandler(
        run_state->step_id, count, req.options());
    if (ph) {
      pss.collect_timeline = true;
      pss.collect_rpcs = ph->should_collect_rpcs();
    }

    run_state->pss = std::move(pss);
    run_state->ph = std::move(ph);
  }

  // Make sure that this is a new set of feeds that are still pending.
  for (size_t i = 0; i < req.num_feeds(); ++i) {
    const string& feed = req.feed_name(i);
    auto it = run_state->pending_inputs.find(feed);
    if (it == run_state->pending_inputs.end()) {
      return errors::InvalidArgument(
          "The feed ", feed, " was not specified in partial_run_setup.");
    } else if (it->second) {
      return errors::InvalidArgument("The feed ", feed,
                                     " has already been fed.");
    }
  }
  // Check that this is a new set of fetches that are still pending.
  for (size_t i = 0; i < req.num_fetches(); ++i) {
    const string& fetch = req.fetch_name(i);
    auto it = run_state->pending_outputs.find(fetch);
    if (it == run_state->pending_outputs.end()) {
      return errors::InvalidArgument(
          "The fetch ", fetch, " was not specified in partial_run_setup.");
    } else if (it->second) {
      return errors::InvalidArgument("The fetch ", fetch,
                                     " has already been fetched.");
    }
  }

  // Ensure that the requested fetches can be computed from the provided feeds.
  {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(
        run_state->rcg->CheckFetches(req, run_state, execution_state_.get()));
  }

  // Determine if this partial run satisfies all the pending inputs and
  // outputs.
  for (size_t i = 0; i < req.num_feeds(); ++i) {
    auto it = run_state->pending_inputs.find(req.feed_name(i));
    it->second = true;
  }
  for (size_t i = 0; i < req.num_fetches(); ++i) {
    auto it = run_state->pending_outputs.find(req.fetch_name(i));
    it->second = true;
  }
  bool is_last_partial_run = run_state->PendingDone();

  Status s = run_state->rcg->RunPartitions(
      env_, run_state->step_id, run_state->count, &run_state->pss, opts, req,
      resp, &cancellation_manager_, is_last_partial_run);

  // Delete the run state if there is an error or all fetches are done.
  if (!s.ok() || is_last_partial_run) {
    ReffedClientGraph* rcg = run_state->rcg;
    run_state->pss.end_micros = Env::Default()->NowMicros();
    // Schedule post-processing and cleanup to be done asynchronously.
    Ref();
    rcg->Ref();
    rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(),
                      req.options(), resp->mutable_metadata());
    cleanup.release();  // MarkRunCompletion called in done closure.
    rcg->CleanupPartitionsAsync(
        run_state->step_id, [this, rcg, prun_handle](const Status& s) {
          if (!s.ok()) {
            LOG(ERROR) << "Cleanup partition error: " << s;
          }
          rcg->Unref();
          MarkRunCompletion();
          Unref();
        });
    mutex_lock l(mu_);
    partial_runs_.erase(prun_handle);
  }
  return s;
}

Status MasterSession::CreateDebuggerState(
    const DebugOptions& debug_options, const RunStepRequestWrapper& req,
    int64 rcg_execution_count,
    std::unique_ptr<DebuggerStateInterface>* debugger_state) {
  TF_RETURN_IF_ERROR(
      DebuggerStateRegistry::CreateState(debug_options, debugger_state));

  std::vector<string> input_names;
  for (size_t i = 0; i < req.num_feeds(); ++i) {
    input_names.push_back(req.feed_name(i));
  }
  std::vector<string> output_names;
  for (size_t i = 0; i < req.num_fetches(); ++i) {
    output_names.push_back(req.fetch_name(i));
  }
  std::vector<string> target_names;
  for (size_t i = 0; i < req.num_targets(); ++i) {
    target_names.push_back(req.target_name(i));
  }

  // TODO(cais): We currently use -1 as a dummy value for session run count.
  // While this counter value is straightforward to define and obtain for
  // DirectSessions, it is less so for non-direct Sessions. Devise a better
  // way to get its value when the need arises.
  TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
      debug_options.global_step(), rcg_execution_count, rcg_execution_count,
      input_names, output_names, target_names));

  return Status::OK();
}

Status MasterSession::DoRunWithLocalExecution(
    CallOptions* opts, const RunStepRequestWrapper& req,
    MutableRunStepResponseWrapper* resp) {
  VLOG(2) << "DoRunWithLocalExecution req: " << req.DebugString();
  PerStepState pss;
  pss.start_micros = Env::Default()->NowMicros();
  auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });

  // Prepare.
  BuildGraphOptions bgopts;
  BuildBuildGraphOptions(req, &bgopts);
  ReffedClientGraph* rcg = nullptr;
  int64 count = 0;
  TF_RETURN_IF_ERROR(StartStep(bgopts, &count, &rcg, false));

  // Unref "rcg" when out of scope.
  core::ScopedUnref unref(rcg);

  std::unique_ptr<DebuggerStateInterface> debugger_state;
  const DebugOptions& debug_options = req.options().debug_options();

  if (!debug_options.debug_tensor_watch_opts().empty()) {
    TF_RETURN_IF_ERROR(
        CreateDebuggerState(debug_options, req, count, &debugger_state));
  }
  TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));

  // Keeps the highest 8 bits 0x01: we reserve some bits of the
  // step_id for future use.
  const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
  TRACEPRINTF("stepid %llu", step_id);

  pss.collect_timeline = req.options().trace_level() == RunOptions::FULL_TRACE;
  pss.collect_rpcs = req.options().trace_level() == RunOptions::FULL_TRACE;
  pss.report_tensor_allocations_upon_oom =
      req.options().report_tensor_allocations_upon_oom();
  // Build the cost model every 'build_cost_model_every' steps after skipping
  // an initial 'build_cost_model_after' steps.
  const int64 build_cost_model_after =
      session_opts_.config.graph_options().build_cost_model_after();
  const int64 build_cost_model_every =
      session_opts_.config.graph_options().build_cost_model();
  pss.collect_costs =
      build_cost_model_every > 0 &&
      ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
  pss.collect_partition_graphs = req.options().output_partition_graphs();

  std::unique_ptr<ProfileHandler> ph =
      rcg->GetProfileHandler(step_id, count, req.options());
  if (ph) {
    pss.collect_timeline = true;
    pss.collect_rpcs = ph->should_collect_rpcs();
  }

  Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
                                &cancellation_manager_, false);
  if (s.ok()) {
    pss.end_micros = Env::Default()->NowMicros();

    // Schedule post-processing and cleanup to be done asynchronously.
    rcg->ProcessStats(step_id, &pss, ph.get(), req.options(),
                      resp->mutable_metadata());
  } else if (errors::IsCancelled(s)) {
    mutex_lock l(mu_);
    if (closed_) {
      if (garbage_collected_) {
        s = errors::Cancelled(
            "Step was cancelled because the session was garbage collected due "
            "to inactivity.");
      } else {
        s = errors::Cancelled(
            "Step was cancelled by an explicit call to `Session::Close()`.");
      }
    }
  }
  Ref();
  rcg->Ref();
  cleanup.release();  // MarkRunCompletion called in done closure.
  rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
    if (!s.ok()) {
      LOG(ERROR) << "Cleanup partition error: " << s;
    }
    rcg->Unref();
    MarkRunCompletion();
    Unref();
  });
  return s;
}

Status MasterSession::Close() {
  {
    mutex_lock l(mu_);
    closed_ = true;  // All subsequent calls to Run() or Extend() will fail.
  }
  cancellation_manager_.StartCancel();
  std::vector<ReffedClientGraph*> to_unref;
  {
    mutex_lock l(mu_);
    while (num_running_ != 0) {
      num_running_is_zero_.wait(l);
    }
    ClearRunsTable(&to_unref, &run_graphs_);
    ClearRunsTable(&to_unref, &partial_run_graphs_);
  }
  for (ReffedClientGraph* rcg : to_unref) rcg->Unref();
  if (should_delete_worker_sessions_) {
    Status s = DeleteWorkerSessions();
    if (!s.ok()) {
      LOG(WARNING) << s;
    }
  }
  return Status::OK();
}

void MasterSession::GarbageCollect() {
  {
    mutex_lock l(mu_);
    closed_ = true;
    garbage_collected_ = true;
  }
  cancellation_manager_.StartCancel();
  Unref();
}

MasterSession::RunState::RunState(const std::vector<string>& input_names,
                                  const std::vector<string>& output_names,
                                  ReffedClientGraph* rcg, const uint64 step_id,
                                  const int64 count)
    : rcg(rcg), step_id(step_id), count(count) {
  // Initially all the feeds and fetches are pending.
  for (auto& name : input_names) {
    pending_inputs[name] = false;
  }
  for (auto& name : output_names) {
    pending_outputs[name] = false;
  }
}

MasterSession::RunState::~RunState() {
  if (rcg) rcg->Unref();
}

bool MasterSession::RunState::PendingDone() const {
  for (const auto& it : pending_inputs) {
    if (!it.second) return false;
  }
  for (const auto& it : pending_outputs) {
    if (!it.second) return false;
  }
  return true;
}

}  // end namespace tensorflow