/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/executor.h"

#include <atomic>
#include <deque>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_segment.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/manual_constructor.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"

namespace tensorflow {
namespace {

// 1-D, 0 element tensor.
static const Tensor* const kEmptyTensor = new Tensor;

bool IsInitializationOp(const Node* node) {
  return node->op_def().allows_uninitialized_input();
}

// Sets the timeline_label field of *node_stats, using data from *node.
// Returns true iff the node is a transfer node.
// TODO(tucker): merge with the DetailText function in session.cc
// in a common location.
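// The resulting label has roughly the form (values illustrative only)
//   "[allocator_name 1.2MB 3.4MB] node_name = OpType(input_a, input_b)"
// for ordinary nodes; for Send/Recv nodes the tensor name and the remote
// device are recorded instead of the input list.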
bool SetTimelineLabel(const Node* node, NodeExecStatsWrapper* stats) {
  bool is_transfer_node = false;
  if (!stats) {
    return is_transfer_node;
  }
  string memory;
  for (auto& all : stats->stats()->memory()) {
    int64 tot = all.total_bytes();
    if (tot >= 0.1 * 1048576.0) {
      int64 peak = all.peak_bytes();
      if (peak > 0) {
        memory =
            strings::StrCat(memory, "[", all.allocator_name(),
                            strings::Printf(" %.1fMB %.1fMB] ", tot / 1048576.0,
                                            peak / 1048576.0));
      } else {
        memory = strings::StrCat(memory, "[", all.allocator_name(),
                                 strings::Printf(" %.1fMB] ", tot / 1048576.0));
      }
    }
  }
  const AttrSlice attrs = node->attrs();
  string text;
  if (IsSend(node)) {
    string tensor_name;
    TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
    string recv_device;
    TF_CHECK_OK(GetNodeAttr(attrs, "recv_device", &recv_device));
    text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
                           "(", tensor_name, " @", recv_device);
    is_transfer_node = true;
  } else if (IsRecv(node)) {
    string tensor_name;
    TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
    string send_device;
    TF_CHECK_OK(GetNodeAttr(attrs, "send_device", &send_device));
    text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
                           "(", tensor_name, " @", send_device);
    is_transfer_node = true;
  } else {
    text =
        strings::StrCat(memory, node->name(), " = ", node->type_string(), "(",
                        str_util::Join(node->requested_inputs(), ", "), ")");
  }
  stats->stats()->set_timeline_label(text);
  return is_transfer_node;
}

// Helper routines for collecting step stats.
namespace nodestats {
inline int64 NowInUsec() { return Env::Default()->NowMicros(); }

void SetScheduled(NodeExecStatsWrapper* stats, int64 t) {
  if (!stats) return;
  stats->stats()->set_scheduled_micros(t);
}

void SetAllStart(NodeExecStatsWrapper* stats) {
  if (!stats) return;
  stats->stats()->set_all_start_micros(NowInUsec());
}

void SetOpStart(NodeExecStatsWrapper* stats) {
  if (!stats) return;
  NodeExecStats* nt = stats->stats();
  DCHECK_NE(nt->all_start_micros(), 0);
  nt->set_op_start_rel_micros(NowInUsec() - nt->all_start_micros());
}

void SetOpEnd(NodeExecStatsWrapper* stats) {
  if (!stats) return;
  NodeExecStats* nt = stats->stats();
  DCHECK_NE(nt->all_start_micros(), 0);
  nt->set_op_end_rel_micros(NowInUsec() - nt->all_start_micros());
}

void SetAllEnd(NodeExecStatsWrapper* stats) {
  if (!stats) return;
  NodeExecStats* nt = stats->stats();
  DCHECK_NE(nt->all_start_micros(), 0);
  nt->set_all_end_rel_micros(NowInUsec() - nt->all_start_micros());
}

void SetOutput(NodeExecStatsWrapper* stats, int slot, const Tensor* v) {
  if (!stats) return;
  DCHECK(v);
  NodeOutput* no = stats->stats()->add_output();
  no->set_slot(slot);
  v->FillDescription(no->mutable_tensor_description());
}

void SetMemory(NodeExecStatsWrapper* stats, OpKernelContext* ctx) {
  if (!stats) return;

  for (const auto& allocator_pair : ctx->wrapped_allocators()) {
    stats->AddAllocation(allocator_pair.first, allocator_pair.second);
  }
  auto* ms = stats->stats()->mutable_memory_stats();
  ms->set_temp_memory_size(ctx->temp_memory_allocated());
  for (const auto& alloc_id : ctx->persistent_alloc_ids()) {
    ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id);
  }
  ms->set_persistent_memory_size(ctx->persistent_memory_allocated());
}

void SetReferencedTensors(NodeExecStatsWrapper* stats,
                          const TensorReferenceVector& tensors) {
  if (!stats) return;
  // be careful not to increment the reference count on any tensor
  // while recording the information
  for (size_t i = 0; i < tensors.size(); ++i) {
    AllocationDescription* description =
        stats->stats()->add_referenced_tensor();
    tensors.at(i).FillDescription(description);
  }
}

}  // namespace nodestats

class ExecutorImpl;
class GraphView;

struct EdgeInfo {
  int dst_id;
  int output_slot : 31;
  // true if this is the last info for output_slot in the EdgeInfo list.
  bool is_last : 1;
  int input_slot;
};

struct NodeItem {
  NodeItem() {}

  // A graph node.
  const Node* node = nullptr;

  // The kernel for this node.
  OpKernel* kernel = nullptr;

  bool kernel_is_expensive : 1;  // True iff kernel->IsExpensive()
  bool kernel_is_async : 1;      // True iff kernel->AsAsync() != nullptr
  bool is_merge : 1;             // True iff IsMerge(node)
  bool is_enter : 1;             // True iff IsEnter(node)
  bool is_exit : 1;              // True iff IsExit(node)
  bool is_control_trigger : 1;   // True iff IsControlTrigger(node)
  bool is_sink : 1;              // True iff IsSink(node)
  // True iff IsEnter(node) || IsExit(node) || IsNextIteration(node)
  bool is_enter_exit_or_next_iter : 1;

  // Cached values of node->num_inputs() and node->num_outputs(), to
  // avoid levels of indirection.
  int num_inputs;
  int num_outputs;

  // ExecutorImpl::tensors_[input_start] is the 1st positional input
  // for this node.
  int input_start = 0;

  // Number of output edges.
  size_t num_output_edges;

  PendingCounts::Handle pending_id;

  const EdgeInfo* output_edge_list() const { return output_edge_base(); }

  // ith output edge.
  const EdgeInfo& output_edge(int i) const {
    DCHECK_GE(i, 0);
    DCHECK_LT(i, num_output_edges);
    return output_edge_base()[i];
  }

  DataType input_type(int i) const {
    DCHECK_LT(i, num_inputs);
    return static_cast<DataType>(input_type_base()[i]);
  }
  DataType output_type(int i) const {
    DCHECK_LT(i, num_outputs);
    return static_cast<DataType>(output_type_base()[i]);
  }

  // Return array of per-output allocator attributes.
  const AllocatorAttributes* output_attrs() const { return output_attr_base(); }

 private:
  friend class GraphView;

  // Variable length section starts immediately after *this
  // (uint8 is enough for DataType).
  //   EdgeInfo            out_edges[num_out_edges];
  //   AllocatorAttributes output_attr[num_outputs];
  //   uint8               input_type[num_inputs];
  //   uint8               output_type[num_outputs];

  // Return pointer to variable length section.
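  // For example (sizes purely illustrative): a node with 3 output edges,
  // 2 inputs and 1 output is laid out in space_ as
  //   [NodeItem | EdgeInfo x3 | AllocatorAttributes x1 | uint8 x2 | uint8 x1]
  // with the whole record rounded up to kItemAlignment by NodeItemBytes().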
  char* var() const {
    return const_cast<char*>(reinterpret_cast<const char*>(this) +
                             sizeof(NodeItem));
  }

  EdgeInfo* output_edge_base() const {
    return reinterpret_cast<EdgeInfo*>(var());
  }
  AllocatorAttributes* output_attr_base() const {
    return reinterpret_cast<AllocatorAttributes*>(var() + sizeof(EdgeInfo) *
                                                              num_output_edges);
  }
  uint8* input_type_base() const {
    return reinterpret_cast<uint8*>(var() +
                                    sizeof(EdgeInfo) * num_output_edges +
                                    sizeof(AllocatorAttributes) * num_outputs);
  }
  uint8* output_type_base() const {
    return reinterpret_cast<uint8*>(
        var() + sizeof(EdgeInfo) * num_output_edges +
        sizeof(AllocatorAttributes) * num_outputs + sizeof(uint8) * num_inputs);
  }

  TF_DISALLOW_COPY_AND_ASSIGN(NodeItem);
};

typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
typedef gtl::InlinedVector<DeviceContext*, 4> DeviceContextVec;
typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;

// Immutable view of a Graph organized for efficient execution.
class GraphView {
 public:
  GraphView() : space_(nullptr) {}
  ~GraphView();

  void Initialize(const Graph* g);
  Status SetAllocAttrs(const Graph* g, const Device* device);

  NodeItem* node(size_t id) const {
    DCHECK_GE(id, 0);
    DCHECK_LT(id, num_nodes_);
    uint32 offset = node_offsets_[id];
    return ((offset == kuint32max)
                ? nullptr
                : reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]));
  }

 private:
  char* InitializeNode(char* ptr, const Node* n);
  size_t NodeItemBytes(const Node* n);

  int32 num_nodes_ = 0;
  uint32* node_offsets_ = nullptr;  // array of size "graph_.num_node_ids()"
  // node_offsets_[id] holds the byte offset for node w/ "id" in space_

  char* space_;  // NodeItem objects are allocated here

  TF_DISALLOW_COPY_AND_ASSIGN(GraphView);
};

class ExecutorImpl : public Executor {
 public:
  ExecutorImpl(const LocalExecutorParams& p, std::unique_ptr<const Graph> g)
      : params_(p), graph_(std::move(g)), gview_() {
    CHECK(p.create_kernel != nullptr);
    CHECK(p.delete_kernel != nullptr);
  }

  ~ExecutorImpl() override {
    for (int i = 0; i < graph_->num_node_ids(); i++) {
      NodeItem* item = gview_.node(i);
      if (item != nullptr) {
        params_.delete_kernel(item->kernel);
      }
    }
    for (auto fiter : frame_info_) {
      delete fiter.second;
    }
  }

  Status Initialize();

  // Process all Nodes in the current graph, attempting to infer the
  // memory allocation attributes to be used wherever they may allocate
  // a tensor buffer.
  Status SetAllocAttrs();

  void RunAsync(const Args& args, DoneCallback done) override;

 private:
  friend class ExecutorState;

  struct ControlFlowInfo {
    gtl::FlatSet<string> unique_frame_names;
    std::vector<string> frame_names;
  };

  struct FrameInfo {
    FrameInfo()
        : input_count(0),
          total_inputs(0),
          pending_counts(nullptr),
          nodes(nullptr) {}

    // The total number of inputs to a frame.
    int input_count;

    // The total number of input tensors of a frame.
    // == sum(nodes[*].num_inputs()) where nodes are the nodes in the frame.
    int total_inputs;

    // Used to determine the next place to allocate space in the
    // pending_counts data structure we'll eventually construct.
    PendingCounts::Layout pending_counts_layout;

    // Each frame has its own PendingCounts only for the nodes in the frame.
    PendingCounts* pending_counts;  // Owned

    // The nodes in a frame. Used only for debugging.
    std::vector<const Node*>* nodes;  // Owned

    ~FrameInfo() {
      delete pending_counts;
      delete nodes;
    }
  };

  static Status BuildControlFlowInfo(const Graph* graph,
                                     ControlFlowInfo* cf_info);
  void InitializePending(const Graph* graph, const ControlFlowInfo& cf_info);

  FrameInfo* EnsureFrameInfo(const string& fname) {
    auto slot = &frame_info_[fname];
    if (*slot == nullptr) {
      *slot = new FrameInfo;
    }
    return *slot;
  }

  // Owned.
  LocalExecutorParams params_;
  std::unique_ptr<const Graph> graph_;
  GraphView gview_;

  // A cached value of params_
  bool device_record_tensor_accesses_ = false;

  // Root nodes (with no in edges) that should form the initial ready queue
  std::vector<const Node*> root_nodes_;

  // Mapping from frame name to static information about the frame.
  // TODO(yuanbyu): We could cache it along with the graph so as to avoid
  // the overhead of constructing it for each executor instance.
  gtl::FlatMap<string, FrameInfo*> frame_info_;

  TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl);
};

// Infer memory allocation attributes of a node n's output,
// based on its use node dst. Note that dst might not be directly
// connected to n by a single edge, but might be a downstream
// consumer of n's output by reference. *attr is updated with any
// necessary attributes.
Status InferAllocAttr(const Node* n, const Node* dst,
                      const DeviceNameUtils::ParsedName& local_dev_name,
                      AllocatorAttributes* attr);

GraphView::~GraphView() {
  static_assert(std::is_trivially_destructible<AllocatorAttributes>::value,
                "Update code if AllocatorAttributes gains a destructor");
  static_assert(std::is_trivially_destructible<EdgeInfo>::value,
                "Update code if EdgeInfo gains a destructor");
  for (int i = 0; i < num_nodes_; i++) {
    NodeItem* n = node(i);
    if (n != nullptr) {
      n->NodeItem::~NodeItem();
      // Memory for "n" itself is held in space_ & gets cleaned up below
    }
  }
  delete[] node_offsets_;
  delete[] space_;
}

size_t GraphView::NodeItemBytes(const Node* n) {
  const size_t num_output_edges = n->out_edges().size();
  const int num_inputs = n->num_inputs();
  const int num_outputs = n->num_outputs();

  // Compute number of bytes needed for NodeItem and variable length data.
  // We do not subtract sizeof(var) since num_inputs/num_outputs might
  // both be zero.
  const size_t raw_bytes =
      sizeof(NodeItem)                             // Fixed
      + num_output_edges * sizeof(EdgeInfo)        // output_edges[...]
      + num_outputs * sizeof(AllocatorAttributes)  // output_attr[...]
      + num_inputs * sizeof(uint8)                 // input_type[num_inputs]
      + num_outputs * sizeof(uint8);               // output_type[num_outputs]
  static constexpr size_t kItemAlignment = sizeof(NodeItem*);
  static_assert(kItemAlignment % alignof(NodeItem) == 0,
                "NodeItem must be aligned with kItemAlignment");
  static_assert(kItemAlignment % alignof(EdgeInfo) == 0,
                "EdgeInfo must be aligned with kItemAlignment");
  static_assert(kItemAlignment % alignof(AllocatorAttributes) == 0,
                "AllocatorAttributes must be aligned with kItemAlignment");
  static_assert(sizeof(NodeItem) % alignof(EdgeInfo) == 0,
                "NodeItem must be aligned with EdgeInfo");
  static_assert(sizeof(NodeItem) % alignof(AllocatorAttributes) == 0,
                "NodeItem must be aligned with AllocatorAttributes");
  static_assert(sizeof(EdgeInfo) % alignof(AllocatorAttributes) == 0,
                "EdgeInfo must be aligned with AllocatorAttributes");
  const size_t bytes =
      ((raw_bytes + kItemAlignment - 1) / kItemAlignment) * kItemAlignment;
  return bytes;
}

char* GraphView::InitializeNode(char* ptr, const Node* n) {
  const int id = n->id();
  CHECK(node_offsets_[id] == kuint32max);  // Initial value in constructor

  const size_t bytes = NodeItemBytes(n);
  constexpr size_t kItemAlignment = sizeof(NodeItem*);
  CHECK_EQ(reinterpret_cast<uintptr_t>(ptr) % kItemAlignment, 0);
  NodeItem* item = reinterpret_cast<NodeItem*>(ptr);

  // We store a 32-bit offset relative to the beginning of space_, so that we
  // only need an array of 32-bit values to map from node id to the NodeItem*,
  // (versus 64 bits on most machines if we just stored an array of NodeItem*
  // pointers). Casting to int64 is needed on 32bit CPU to avoid comparing
  // values as "int" vs "size_t" in CHECK_LE.
  CHECK_LE(static_cast<int64>(ptr - space_), kuint32max);
  const uint32 offset = static_cast<uint32>(ptr - space_);
  node_offsets_[id] = offset;
  ptr += bytes;

  const size_t num_output_edges = n->out_edges().size();
  const int num_inputs = n->num_inputs();
  const int num_outputs = n->num_outputs();

  new (item) NodeItem();
  item->num_inputs = num_inputs;
  item->num_outputs = num_outputs;
  item->num_output_edges = num_output_edges;

  // Fill output edges.
  // Keep track of the last EdgeInfo in the EdgeInfo array that references
  // a given output slot. For all but the last, we need to do a copy of the
  // Tensor when propagating results downstream in the graph, but for the
  // last one, we can just do a move of the Tensor object to propagate it.
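  // For example, if output slot 0 feeds three edges e1, e2 and e3 (in
  // out_edges() order), only the EdgeInfo written last for slot 0 gets
  // is_last = true, so the Tensor is copied for the first two edges and
  // moved for the final one when outputs are propagated downstream.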
  gtl::InlinedVector<EdgeInfo*, 4> last_indices(num_outputs, nullptr);
  EdgeInfo* dst_edge = item->output_edge_base();
  for (auto e : n->out_edges()) {
    dst_edge->dst_id = e->dst()->id();
    CHECK_LE(e->src_output(), 0x3FFFFFFF);  // Must fit in 31 bits
    dst_edge->output_slot = e->src_output();
    dst_edge->is_last = false;
    const int output_slot = dst_edge->output_slot;
    if (output_slot >= 0) {
      last_indices[output_slot] = dst_edge;
    }
    dst_edge->input_slot = e->dst_input();
    dst_edge++;
  }
  for (EdgeInfo* edge_info : last_indices) {
    if (edge_info != nullptr) {
      edge_info->is_last = true;
    }
  }

  AllocatorAttributes* output_attrs = item->output_attr_base();
  for (int i = 0; i < num_outputs; i++) {
    new (&output_attrs[i]) AllocatorAttributes();
  }

  DCHECK_LT(DataType_MAX, 255);  // Must fit in uint8
  uint8* input_types = item->input_type_base();
  for (int i = 0; i < num_inputs; i++) {
    input_types[i] = static_cast<uint8>(n->input_type(i));
    DCHECK_EQ(item->input_type(i), n->input_type(i));
  }

  uint8* output_types = item->output_type_base();
  for (int i = 0; i < num_outputs; i++) {
    output_types[i] = static_cast<uint8>(n->output_type(i));
    DCHECK_EQ(item->output_type(i), n->output_type(i));
  }
  return ptr;
}

void GraphView::Initialize(const Graph* g) {
  CHECK(node_offsets_ == nullptr);
  const int num_nodes = g->num_node_ids();
  num_nodes_ = num_nodes;
  size_t total_bytes = 0;
  for (const Node* n : g->nodes()) {
    total_bytes += NodeItemBytes(n);
  }

  node_offsets_ = new uint32[num_nodes];
  for (int i = 0; i < num_nodes; i++) {
    node_offsets_[i] = kuint32max;
  }

  space_ = new char[total_bytes];  // NodeItem objects are allocated here
  char* ptr = space_;
  for (const Node* n : g->nodes()) {
    ptr = InitializeNode(ptr, n);
  }
  CHECK_EQ(ptr, space_ + total_bytes);
}

void GetMaxPendingCounts(const Node* n, size_t* max_pending,
                         size_t* max_dead_count) {
  const size_t num_in_edges = n->in_edges().size();
  size_t initial_count;
  if (IsMerge(n)) {
    // Merge waits for all control inputs, so we initialize the pending
    // count to be the number of control edges.
    int32 num_control_edges = 0;
    for (const Edge* edge : n->in_edges()) {
      if (edge->IsControlEdge()) {
        num_control_edges++;
      }
    }
    // Use bit 0 to indicate if we are waiting for a ready live data input.
    initial_count = 1 + (num_control_edges << 1);
  } else {
    initial_count = num_in_edges;
  }

  *max_pending = initial_count;
  *max_dead_count = num_in_edges;
}

Status ExecutorImpl::Initialize() {
  gview_.Initialize(graph_.get());

  // Build the information about frames in this subgraph.
  ControlFlowInfo cf_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &cf_info));

  // Cache this value so we make this virtual function call once, rather
  // than O(# steps * # nodes per step) times.
  device_record_tensor_accesses_ =
      params_.device->RequiresRecordingAccessedTensors();

  for (auto& it : cf_info.unique_frame_names) {
    EnsureFrameInfo(it)->nodes = new std::vector<const Node*>;
  }

  // Preprocess every node in the graph to create an instance of op
  // kernel for each node.
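  // For each node this also records the frame it belongs to, reserves its
  // slice of per-frame input storage via input_start, and creates a handle
  // in that frame's PendingCounts layout sized by GetMaxPendingCounts().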
  for (const Node* n : graph_->nodes()) {
    const int id = n->id();
    const string& frame_name = cf_info.frame_names[id];
    FrameInfo* frame_info = EnsureFrameInfo(frame_name);

    // See if this node is a root node, and if so, add to root_nodes_.
    if (n->in_edges().empty()) {
      root_nodes_.push_back(n);
    }

    NodeItem* item = gview_.node(id);
    item->node = n;

    item->input_start = frame_info->total_inputs;
    frame_info->total_inputs += n->num_inputs();

    Status s = params_.create_kernel(n->def(), &item->kernel);
    if (!s.ok()) {
      item->kernel = nullptr;
      s = AttachDef(s, *n);
      LOG(ERROR) << "Executor failed to create kernel. " << s;
      return s;
    }
    CHECK(item->kernel);
    item->kernel_is_expensive = item->kernel->IsExpensive();
    item->kernel_is_async = (item->kernel->AsAsync() != nullptr);
    item->is_merge = IsMerge(n);
    item->is_enter = IsEnter(n);
    item->is_exit = IsExit(n);
    item->is_control_trigger = IsControlTrigger(n);
    item->is_sink = IsSink(n);
    item->is_enter_exit_or_next_iter =
        (IsEnter(n) || IsExit(n) || IsNextIteration(n));

    // Compute the maximum values we'll store for this node in the
    // pending counts data structure, and allocate a handle in
    // that frame's pending counts data structure that has enough
    // space to store these maximal count values.
    size_t max_pending, max_dead;
    GetMaxPendingCounts(n, &max_pending, &max_dead);
    item->pending_id =
        frame_info->pending_counts_layout.CreateHandle(max_pending, max_dead);

    // Initialize static information about the frames in the graph.
    frame_info->nodes->push_back(n);
    if (IsEnter(n)) {
      string enter_name;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name));
      EnsureFrameInfo(enter_name)->input_count++;
    }
  }

  // Initialize PendingCounts only after item->pending_id is initialized for
  // all nodes.
  InitializePending(graph_.get(), cf_info);

  return gview_.SetAllocAttrs(graph_.get(), params_.device);
}

Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) {
  Status s;
  DeviceNameUtils::ParsedName local_dev_name = device->parsed_name();

  for (const Node* n : g->nodes()) {
    NodeItem* item = node(n->id());
    AllocatorAttributes* attrs = item->output_attr_base();

    // Examine the out edges of each node looking for special use
    // cases that may affect memory allocation attributes.
    for (auto e : n->out_edges()) {
      if (!e->IsControlEdge()) {
        AllocatorAttributes attr;
        s = InferAllocAttr(n, e->dst(), local_dev_name, &attr);
        if (!s.ok()) return s;
        if (attr.value != 0) {
          attrs[e->src_output()].Merge(attr);
        }
      }
    }

    for (int out = 0; out < n->num_outputs(); out++) {
      const OpKernel* op_kernel = item->kernel;
      DCHECK_LT(out, op_kernel->output_memory_types().size());
      bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY;
      if (on_host) {
        AllocatorAttributes h;
        h.set_on_host(on_host);
        attrs[out].Merge(h);
      }
    }
  }
  return s;
}

Status InferAllocAttr(const Node* n, const Node* dst,
                      const DeviceNameUtils::ParsedName& local_dev_name,
                      AllocatorAttributes* attr) {
  Status s;
  // Note that it's possible for *n to be a Recv and *dst to be a Send,
  // so these two cases are not mutually exclusive.
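  // In summary (as implemented below): if a Recv's sender lives in a
  // different address space, the output is marked nic_compatible; if it is a
  // host-side Recv fed from a non-CPU device, it is marked gpu_compatible.
  // The mirror-image rules are applied when the consumer dst is a Send.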
  if (IsRecv(n)) {
    string src_name;
    s = GetNodeAttr(n->attrs(), "send_device", &src_name);
    if (!s.ok()) return s;
    DeviceNameUtils::ParsedName parsed_src_name;
    if (!DeviceNameUtils::ParseFullName(src_name, &parsed_src_name)) {
      s = errors::Internal("Bad send_device attr '", src_name, "' in node ",
                           n->name());
      return s;
    }
    if (!DeviceNameUtils::IsSameAddressSpace(parsed_src_name, local_dev_name)) {
      // Value is going to be the sink of an RPC.
      attr->set_nic_compatible(true);
      VLOG(2) << "node " << n->name() << " is the sink of an RPC in";
    } else if ((local_dev_name.type == "CPU" || n->IsHostRecv()) &&
               parsed_src_name.type != "CPU") {
      // Value is going to be the sink of a local DMA from GPU to CPU (or other
      // types of accelerators).
      attr->set_gpu_compatible(true);
      VLOG(2) << "node " << n->name() << " is the sink of a gpu->cpu copy";
    } else {
      VLOG(2) << "default alloc case local type " << local_dev_name.type
              << " remote type " << parsed_src_name.type;
    }
  }
  if (IsSend(dst)) {
    string dst_name;
    s = GetNodeAttr(dst->attrs(), "recv_device", &dst_name);
    if (!s.ok()) return s;
    DeviceNameUtils::ParsedName parsed_dst_name;
    if (!DeviceNameUtils::ParseFullName(dst_name, &parsed_dst_name)) {
      s = errors::Internal("Bad recv_device attr '", dst_name, "' in node ",
                           n->name());
      return s;
    }
    if (!DeviceNameUtils::IsSameAddressSpace(parsed_dst_name, local_dev_name)) {
      // Value is going to be the source of an RPC.
      attr->set_nic_compatible(true);
      VLOG(2) << "node " << n->name() << " is the source of an RPC out";
    } else if ((local_dev_name.type == "CPU" || dst->IsHostSend()) &&
               parsed_dst_name.type != "CPU") {
      // Value is going to be the source of a local DMA from CPU to GPU (or
      // other types of accelerators).
      // Note that this does not cover the case where the allocation of the
      // output tensor is not generated by the src: n.
      attr->set_gpu_compatible(true);
      VLOG(2) << "node " << n->name() << " is the source of a cpu->gpu copy";
    } else {
      VLOG(2) << "default alloc case local type " << local_dev_name.type
              << " remote type " << parsed_dst_name.type;
    }
  }
  return s;
}

// The state associated with one invocation of ExecutorImpl::Run.
// ExecutorState dispatches nodes when they become ready and keeps
// track of how many predecessors of a node have not yet completed
// (pending_).
class ExecutorState {
 public:
  ExecutorState(const Executor::Args& args, ExecutorImpl* impl);
  ~ExecutorState();

  void RunAsync(Executor::DoneCallback done);

 private:
  // Either a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
  // TODO(yuanbyu): A better way to do "has_value"?
  struct Entry {
    Entry() {}
    Entry(const Entry& other)
        : ref(other.ref),
          ref_mu(other.ref_mu),
          has_value(other.has_value),
          val_field_is_set(other.val_field_is_set),
          alloc_attr(other.alloc_attr),
          device_context(other.device_context) {
      if (val_field_is_set) {
        val.Init(*other.val);
      }
    }
    ~Entry() {
      if (val_field_is_set) val.Destroy();
    }

    Entry& operator=(const Entry& other) {
      if (val_field_is_set) {
        val.Destroy();
      }
      ref = other.ref;
      ref_mu = other.ref_mu;
      has_value = other.has_value;
      val_field_is_set = other.val_field_is_set;
      alloc_attr = other.alloc_attr;
      device_context = other.device_context;
      if (val_field_is_set) {
        val.Init(*other.val);
      }
      return *this;
    }

    Entry& operator=(Entry&& other) {
      if (val_field_is_set) {
        val.Destroy();
      }
      ref = other.ref;
      ref_mu = other.ref_mu;
      has_value = other.has_value;
      val_field_is_set = other.val_field_is_set;
      alloc_attr = other.alloc_attr;
      device_context = other.device_context;
      if (val_field_is_set) {
        val.Init(std::move(*other.val));
      }
      return *this;
    }

    // Clears the <val> field.
    void ClearVal() {
      if (val_field_is_set) {
        val.Destroy();
        val_field_is_set = false;
        has_value = false;
      }
    }

    // A tensor value, if val_field_is_set.
    ManualConstructor<Tensor> val;

    Tensor* ref = nullptr;    // A tensor reference.
    mutex* ref_mu = nullptr;  // mutex for *ref if ref is not nullptr.

    // Whether the value exists, either in <val> or <ref>.
    bool has_value = false;

    bool val_field_is_set = false;

    // The attributes of the allocator that creates the tensor.
    AllocatorAttributes alloc_attr;

    // Every entry carries an optional DeviceContext containing
    // Device-specific information about how the Tensor was produced.
    DeviceContext* device_context = nullptr;
  };

  // Contains a value for [node->id()] for the device context assigned by the
  // device at the beginning of a step.
  DeviceContextMap device_context_map_;

  struct TaggedNode;
  typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;
  typedef gtl::InlinedVector<Entry, 4> EntryVector;

  struct IterationState {
    explicit IterationState(const PendingCounts* pending_counts,
                            int total_input_tensors)
        : input_tensors(new Entry[total_input_tensors]),
          outstanding_ops(0),
          outstanding_frame_count(0),
          counts_(*pending_counts) {  // Initialize with copy of *pending_counts
    }

    // The state of an iteration.

    // One copy per iteration. For iteration k, i-th node's j-th input is in
    // input_tensors[k][impl_->nodes[i].input_start + j]. An entry is either
    // a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
    //
    // NOTE: No need to protect input_tensors[i] by any locks because it
    // is resized once. Each element of tensors_ is written once by the
    // source node of an edge and is cleared by the destination of the same
    // edge. The latter node is never run concurrently with the former node.
    Entry* input_tensors;

    // The number of outstanding ops for each iteration.
    size_t outstanding_ops;

    // The number of outstanding frames for each iteration.
    int outstanding_frame_count;
    int pending(PendingCounts::Handle h) { return counts_.pending(h); }
    int decrement_pending(PendingCounts::Handle h, int v) {
      return counts_.decrement_pending(h, v);
    }
    // Mark a merge node as live
    // REQUIRES: Node corresponding to "h" is a merge node
    void mark_live(PendingCounts::Handle h) { counts_.mark_live(h); }
    // Mark a node to show that processing has started.
    void mark_started(PendingCounts::Handle h) { counts_.mark_started(h); }
    // Mark a node to show that processing has completed.
    void mark_completed(PendingCounts::Handle h) { counts_.mark_completed(h); }
    PendingCounts::NodeState node_state(PendingCounts::Handle h) {
      return counts_.node_state(h);
    }

    int dead_count(PendingCounts::Handle h) { return counts_.dead_count(h); }
    void increment_dead_count(PendingCounts::Handle h) {
      counts_.increment_dead_count(h);
    }
    void adjust_for_activation(PendingCounts::Handle h, bool increment_dead,
                               int* pending_result, int* dead_result) {
      counts_.adjust_for_activation(h, increment_dead, pending_result,
                                    dead_result);
    }

    ~IterationState() { delete[] input_tensors; }

   private:
    PendingCounts counts_;
  };

  struct FrameState {
    explicit FrameState(const ExecutorImpl* impl, int parallel_iters)
        : executor(impl),
          max_parallel_iterations(parallel_iters),
          num_outstanding_iterations(1) {}

    // A new frame is created for each loop. Execution starts at iteration 0.
    // When a value at iteration 0 passes through a NextIteration node,
    // iteration 1 is created and starts running. Note that iteration 0 may
    // still be running so multiple iterations may run in parallel. The
    // frame maintains the state of iterations in several data structures
    // such as pending_count and input_tensors. When iteration 0 completes,
    // we garbage collect the state of iteration 0.
    //
    // A frame instance is considered "done" and can be garbage collected
    // if all its inputs have entered and all its iterations are "done".
    //
    // A frame manages the live iterations of an iterative computation.
    // Iteration i is considered "done" when there are no outstanding ops,
    // frames at iteration i are done, all recvs for this iteration are
    // completed, and iteration i-1 is done. For iteration 0, we instead
    // wait for there to be no more pending inputs of the frame.
    //
    // Frames and iterations are garbage collected once they are done.
    // The state we need to keep around is highly dependent on the
    // parallelism enabled by the scheduler. We may want to have the
    // scheduler dynamically control the outstanding number of live
    // parallel frames and iterations. To reduce the state space, the
    // scheduler might want to schedule ops in inner frames first and
    // lower iterations first.
    //
    // This frame state is mostly initialized lazily on demand so we
    // don't introduce unnecessary overhead.

    // The executor the frame is in.
    const ExecutorImpl* executor = nullptr;

    // The name of this frame, which is the concatenation of its parent
    // frame name, the iteration of the parent frame when this frame was
    // created, and the value of the attr 'frame_name'.
    string frame_name;

    // The unique id for this frame. Generated by fingerprinting
    // frame_name.
    uint64 frame_id;

    // The iteration id of its parent frame when this frame is created.
    // -1 if there is no parent frame. The frame_name/parent_iter pair
    // uniquely identifies this FrameState.
    int64 parent_iter = -1;

    // The FrameState of its parent frame.
    FrameState* parent_frame = nullptr;

    // The maximum allowed number of parallel iterations.
    const int max_parallel_iterations;

    // The number of inputs this frame is still waiting for.
    int num_pending_inputs = 0;

    // The highest iteration number we have reached so far in this frame.
    int64 iteration_count GUARDED_BY(mu) = 0;

    // The number of outstanding iterations.
    int num_outstanding_iterations GUARDED_BY(mu) = 1;

    // The active iteration states of this frame.
    gtl::InlinedVector<IterationState*, 12> iterations;

    // The NextIteration nodes to enter a new iteration. If the number of
    // outstanding iterations reaches the limit, we will defer the start of
    // the next iteration until the number of outstanding iterations falls
    // below the limit.
    std::vector<std::pair<const Node*, Entry>> next_iter_roots GUARDED_BY(mu);

    // The values of the loop invariants for this loop. They are added into
    // this list as they "enter" the frame. When a loop invariant enters,
    // we make it available to all active iterations. When the frame starts
    // a new iteration, we make all the current loop invariants available
    // to the new iteration.
    std::vector<std::pair<const Node*, Entry>> inv_values GUARDED_BY(mu);

    // The list of dead exit nodes for the current highest iteration. We
    // will only "execute" the dead exits of the final iteration.
    std::vector<const Node*> dead_exits GUARDED_BY(mu);

    // Static information specific to this frame.
    PendingCounts* pending_counts = nullptr;
    int total_input_tensors = 0;
    std::vector<const Node*>* nodes = nullptr;

    // Lock ordering: ExecutorState.mu_ < mu.
    mutex mu;

    void InitializeFrameInfo(const string& enter_name) {
      auto it_frame_info = executor->frame_info_.find(enter_name);
      DCHECK(it_frame_info != executor->frame_info_.end());
      ExecutorImpl::FrameInfo* finfo = it_frame_info->second;
      pending_counts = finfo->pending_counts;
      total_input_tensors = finfo->total_inputs;
      num_pending_inputs = finfo->input_count;
      nodes = finfo->nodes;
    }

    inline IterationState* GetIteration(int64 iter)
        EXCLUSIVE_LOCKS_REQUIRED(mu) {
      size_t index = iter % iterations.size();
      return iterations[index];
    }

    inline void SetIteration(int64 iter, IterationState* state)
        EXCLUSIVE_LOCKS_REQUIRED(mu) {
      size_t index = iter % iterations.size();
      DCHECK(state == nullptr || iterations[index] == nullptr);
      iterations[index] = state;
    }

    // Decrement the outstanding op count and clean up the iterations in the
    // frame. Return true iff the execution of the frame is done.
    inline bool DecrementOutstandingOps(const GraphView* gview, int64 iter,
                                        TaggedNodeSeq* ready) {
      mutex_lock l(mu);
      return DecrementOutstandingOpsLocked(gview, iter, ready);
    }

    // Decrement the outstanding op count and clean up the iterations in the
    // frame. Return true iff the execution of the frame is done.
    inline bool DecrementOutstandingOpsLocked(const GraphView* gview,
                                              int64 iter, TaggedNodeSeq* ready)
        EXCLUSIVE_LOCKS_REQUIRED(mu) {
      IterationState* istate = GetIteration(iter);
      istate->outstanding_ops--;
      if (istate->outstanding_ops != 0) {
        return false;
      } else {
        return CleanupIterations(gview, iter, ready);
      }
    }

    // Returns true if the computation in the frame is completed.
    inline bool IsFrameDone() EXCLUSIVE_LOCKS_REQUIRED(mu) {
      return (num_pending_inputs == 0 && num_outstanding_iterations == 0);
    }

    // Returns true if the iteration of the frame is completed.
    bool IsIterationDone(int64 iter) EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Increments the iteration id. If this is a new iteration, initialize it.
    void IncrementIteration(const GraphView* gview, TaggedNodeSeq* ready)
        EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Activate all the deferred NextIteration nodes in a new iteration.
    void ActivateNexts(const GraphView* gview, int64 iter, TaggedNodeSeq* ready)
        EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Activate all the current loop invariants in a new iteration.
    void ActivateLoopInvs(const GraphView* gview, int64 iter,
                          TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Add a new loop invariant and make it available to all active iterations.
    void AddLoopInv(const NodeItem* item, const Entry& value,
                    TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Activate the successors of a node. Contents of *outputs are left in an
    // indeterminate state after returning from this method.
    void ActivateNodes(const NodeItem* item, const bool is_dead, int64 iter,
                       EntryVector* outputs, TaggedNodeSeq* ready)
        EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Cleanup iterations of this frame starting from iteration iter.
    bool CleanupIterations(const GraphView* gview, int64 iter,
                           TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu);

    ~FrameState() {
      for (size_t i = 0; i < iterations.size(); ++i) {
        delete iterations[i];
        iterations[i] = nullptr;
      }
    }
  };

  // A tagged node: <frame*, iter, node*>.
  struct TaggedNode {
    const Node* node = nullptr;
    FrameState* input_frame = nullptr;
    int64 input_iter = -1;
    bool is_dead = false;

    TaggedNode(const Node* t_node, FrameState* in_frame, int64 in_iter,
               bool dead) {
      node = t_node;
      input_frame = in_frame;
      input_iter = in_iter;
      is_dead = dead;
    }
  };

  // A drop-in replacement for std::deque<TaggedNode>. We typically don't
  // have that many nodes in the ready queue, so we just use a vector and
  // don't free up memory from the queue as we consume nodes.
  class TaggedNodeReadyQueue {
   public:
    TaggedNodeReadyQueue() : front_index_(0) {}

    void push_back(TaggedNode node) { ready_.push_back(node); }
    TaggedNode front() const {
      DCHECK_LT(front_index_, ready_.size());
      return ready_[front_index_];
    }
    void pop_front() {
      DCHECK_LT(front_index_, ready_.size());
      front_index_++;
      if ((front_index_ == ready_.size()) || (front_index_ > 16384)) {
        if (front_index_ == ready_.size()) {
          ready_.clear();
        } else {
          // Lots of unused entries at beginning of vector: move everything
          // down to start of vector.
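          // (The 16384 threshold bounds how much consumed prefix we keep
          // around before compacting, so memory does not grow without bound
          // during a long-running step.)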
          ready_.erase(ready_.begin(), ready_.begin() + front_index_);
        }
        front_index_ = 0;
      }
    }
    bool empty() const { return ready_.empty(); }
    const TaggedNode* begin() const { return ready_.begin() + front_index_; }
    const TaggedNode* end() const { return ready_.end(); }

   private:
    gtl::InlinedVector<TaggedNode, 16> ready_;
    int front_index_;
  };

  struct AsyncState;

  const bool vlog_;  // true if VLOG_IS_ON(1). Used to check vlog cheaply.

  // true if LogMemory::IsEnabled(). Used to check memory enabled cheaply.
  const bool log_memory_;

  int64 step_id_;
  // Not owned.
  Rendezvous* rendezvous_;
  SessionState* session_state_;
  TensorStore* tensor_store_;
  // Step-local container.
  ScopedStepContainer* step_container_;
  StepStatsCollector* stats_collector_;
  // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
  // instead of a pointer? (avoids having to delete).
  checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
  CallFrameInterface* call_frame_;
  const ExecutorImpl* impl_;
  CancellationManager* cancellation_manager_;
  Executor::Args::Runner runner_;
  bool sync_on_finish_;

  // Owned.

  // A flag that is set on error after the frame state has been
  // dumped for diagnostic purposes.
  bool dumped_on_error_ = false;

  // The root frame in which the execution of this step is started.
  FrameState* root_frame_;

  // Invoked when the execution finishes.
  Executor::DoneCallback done_cb_;

  std::atomic_int_fast32_t num_outstanding_ops_;

  mutex mu_;
  Status status_ GUARDED_BY(mu_);

  // Mapping from frame name to outstanding frames. A new frame is created
  // at some iteration of an active frame. So the unique key for the new
  // child frame is composed of the name of the parent frame, the iteration
  // number at which the parent frame is creating the new frame, and the
  // name of the new frame from nodedef.
  gtl::FlatMap<string, FrameState*> outstanding_frames_ GUARDED_BY(mu_);

  // The unique name of a frame.
  inline string MakeFrameName(FrameState* frame, int64 iter_id,
                              const string& name) {
    return strings::StrCat(frame->frame_name, ";", iter_id, ";", name);
  }

  // Find an existing or create a new child frame in the frame 'frame' at
  // iteration 'iter'.
  void FindOrCreateChildFrame(FrameState* frame, int64 iter, const Node* node,
                              FrameState** child);

  // Delete a frame. Called when the frame is done.
  void DeleteFrame(FrameState* frame, TaggedNodeSeq* ready);

  // Cleanup frames and iterations starting from frame/iter. Called when
  // a child frame is done.
  void CleanupFramesIterations(FrameState* frame, int64 iter,
                               TaggedNodeSeq* ready);

  // Process a ready node in current thread.
  void Process(TaggedNode node, int64 scheduled_usec);

  // Before invoking item->kernel, fills in its "inputs".
  Status PrepareInputs(const NodeItem& item, Entry* first_input,
                       TensorValueVec* inputs,
                       DeviceContextVec* input_device_contexts,
                       AllocatorAttributeVec* input_alloc_attrs,
                       bool* is_input_dead);

  // After item->kernel computation is done, processes its outputs.
  Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
                        EntryVector* outputs, NodeExecStatsWrapper* stats);

  // After processing the outputs, propagates the outputs to their dsts.
  // Contents of *outputs are left in an indeterminate state after
  // returning from this method.
  void PropagateOutputs(const TaggedNode& tagged_node, const NodeItem* item,
                        EntryVector* outputs, TaggedNodeSeq* ready);

  // "node" just finishes. Takes ownership of "stats". Returns true if
  // execution has completed.
  bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready,
                NodeExecStatsWrapper* stats,
                TaggedNodeReadyQueue* inline_ready);

  // Schedule all the expensive nodes in 'ready', and put all the inexpensive
  // nodes in 'ready' into 'inline_ready'.
  void ScheduleReady(const TaggedNodeSeq& ready,
                     TaggedNodeReadyQueue* inline_ready);

  // For debugging/logging only.
  inline void MaybeMarkCompleted(FrameState* frame, int64 iter, int64 id);

  // Provide debugging output about an outstanding node in the executor.
  void DumpPendingNodeState(const int node_id, const Entry* input_vector,
                            bool show_nodes_with_no_ready_inputs);
  void DumpActiveNodeState(const int node_id, const Entry* input_vector);

  // Provide debugging output about an outstanding iteration in the executor.
  void DumpIterationState(const FrameState* frame, IterationState* iteration);

  // Provide debugging output of the state of the executor.
  void DumpState();
  const Tensor* GetTensorValueForDump(const Entry& input);

  // Clean up when this executor is done.
  void Finish();

  // A standalone routine for this expression so that we can express
  // that we don't want thread safety analysis on this reference (it's
  // safe to do without the lock because the iterations array never
  // resizes and this particular iteration's array element will not
  // be changed out from under us because the iteration is still alive).
  Entry* GetInputTensors(FrameState* input_frame,
                         int64 input_iter) const NO_THREAD_SAFETY_ANALYSIS {
    return input_frame->GetIteration(input_iter)->input_tensors;
  }
};

ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
    : vlog_(VLOG_IS_ON(1)),
      log_memory_(LogMemory::IsEnabled()),
      step_id_(args.step_id),
      rendezvous_(args.rendezvous),
      session_state_(args.session_state),
      tensor_store_(args.tensor_store),
      step_container_(args.step_container),
      stats_collector_(args.stats_collector),
      slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
      call_frame_(args.call_frame),
      impl_(impl),
      cancellation_manager_(args.cancellation_manager),
      runner_(args.runner),
      sync_on_finish_(args.sync_on_finish),
      num_outstanding_ops_(0) {
  // We start the entire execution in iteration 0 of the root frame
  // so let us create the root frame and the state for iteration 0.
  // We assume root_frame_->frame_name.empty().
  root_frame_ = new FrameState(impl_, 1);
  root_frame_->frame_id = 0;  // must be 0
  root_frame_->InitializeFrameInfo(root_frame_->frame_name);

  // Initialize iteration 0.
  root_frame_->iterations.resize(root_frame_->max_parallel_iterations);
  root_frame_->iterations[0] = new IterationState(
      root_frame_->pending_counts, root_frame_->total_input_tensors);

  outstanding_frames_.insert({root_frame_->frame_name, root_frame_});
}

ExecutorState::~ExecutorState() {
  for (auto name_frame : outstanding_frames_) {
    delete name_frame.second;
  }
  for (auto it : device_context_map_) {
    it->Unref();
  }
  delete slice_reader_cache_;
}

Status ExecutorImpl::BuildControlFlowInfo(const Graph* g,
                                          ControlFlowInfo* cf_info) {
  const int num_nodes = g->num_node_ids();
  cf_info->frame_names.resize(num_nodes);
  std::vector<Node*> parent_nodes;
  parent_nodes.resize(num_nodes);
  std::vector<bool> visited;
  visited.resize(num_nodes);

  string frame_name;
  std::deque<Node*> ready;

  // Initialize with the root nodes.
  for (Node* n : g->nodes()) {
    if (n->in_edges().empty()) {
      visited[n->id()] = true;
      cf_info->unique_frame_names.insert(frame_name);
      ready.push_back(n);
    }
  }

  while (!ready.empty()) {
    Node* curr_node = ready.front();
    int curr_id = curr_node->id();
    ready.pop_front();

    Node* parent = nullptr;
    if (IsEnter(curr_node)) {
      // Enter a child frame.
      TF_RETURN_IF_ERROR(
          GetNodeAttr(curr_node->attrs(), "frame_name", &frame_name));
      parent = curr_node;
    } else if (IsExit(curr_node)) {
      // Exit to the parent frame.
      parent = parent_nodes[curr_id];
      frame_name = cf_info->frame_names[parent->id()];
      parent = parent_nodes[parent->id()];
    } else {
      parent = parent_nodes[curr_id];
      frame_name = cf_info->frame_names[curr_id];
    }

    for (const Edge* out_edge : curr_node->out_edges()) {
      Node* out = out_edge->dst();
      const int out_id = out->id();

      // Add to ready queue if not visited.
      bool is_visited = visited[out_id];
      if (!is_visited) {
        ready.push_back(out);
        visited[out_id] = true;

        // Process the node 'out'.
        cf_info->frame_names[out_id] = frame_name;
        parent_nodes[out_id] = parent;
        cf_info->unique_frame_names.insert(frame_name);
      }
    }
  }

  return Status::OK();
}

void ExecutorImpl::InitializePending(const Graph* graph,
                                     const ControlFlowInfo& cf_info) {
  for (auto& it : cf_info.unique_frame_names) {
    FrameInfo* finfo = EnsureFrameInfo(it);
    PendingCounts* counts = new PendingCounts(finfo->pending_counts_layout);
    DCHECK_EQ(finfo->pending_counts, nullptr);
    finfo->pending_counts = counts;
  }
  for (const Node* n : graph->nodes()) {
    const int id = n->id();
    const string& name = cf_info.frame_names[id];
    size_t max_pending, max_dead;
    GetMaxPendingCounts(n, &max_pending, &max_dead);
    const NodeItem* item = gview_.node(id);
    PendingCounts* counts = EnsureFrameInfo(name)->pending_counts;
    counts->set_initial_count(item->pending_id, max_pending);
  }
}

void ExecutorState::RunAsync(Executor::DoneCallback done) {
  const Graph* graph = impl_->graph_.get();
  TaggedNodeSeq ready;

  // Ask the device to fill in the device context map.
  Device* device = impl_->params_.device;
  const Status fill_status =
      device->FillContextMap(graph, &device_context_map_);
  if (!fill_status.ok()) {
    done(fill_status);
    return;
  }

  // Initialize the ready queue.
  for (const Node* n : impl_->root_nodes_) {
    DCHECK_EQ(n->in_edges().size(), 0);
    ready.push_back(TaggedNode{n, root_frame_, 0, false});
  }
  if (ready.empty()) {
    done(Status::OK());
  } else {
    num_outstanding_ops_ = ready.size();
    root_frame_->iterations[0]->outstanding_ops = ready.size();
    done_cb_ = std::move(done);
    // Schedule to run all the ready ops in thread pool.
    ScheduleReady(ready, nullptr);
  }
}

// State kept alive for executing an asynchronous node in another
// thread. NOTE: We need to make a copy of p.input,
// p.input_device_contexts, and p.input_alloc_attrs for asynchronous
// kernels because OpKernelContext methods like input_type(i) need
// the params to point to valid input type vectors. This is not an
// issue for sync kernels because these vectors are kept on the stack.
struct ExecutorState::AsyncState {
  AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
             const NodeItem* _item, Entry* _first_input,
             NodeExecStatsWrapper* _stats)
      : saved_inputs(*p.inputs),
        saved_input_device_contexts(*p.input_device_contexts),
        saved_input_alloc_attrs(*p.input_alloc_attrs),
        params(p),
        tagged_node(_tagged_node),
        item(_item),
        first_input(_first_input),
        // ParamsButClearingEigenGPUDevice does equivalent of
        //   params.eigen_gpu_device = nullptr;
        ctx(ParamsButClearingEigenGPUDevice(&params), item->num_outputs),
        stats(_stats) {
    params.inputs = &saved_inputs;
    params.input_device_contexts = &saved_input_device_contexts;
    params.input_alloc_attrs = &saved_input_alloc_attrs;
  }

  TensorValueVec saved_inputs;
  DeviceContextVec saved_input_device_contexts;
  AllocatorAttributeVec saved_input_alloc_attrs;
  OpKernelContext::Params params;
  TaggedNode tagged_node;
  const NodeItem* item;
  Entry* first_input;
  OpKernelContext ctx;
  NodeExecStatsWrapper* stats;

 private:
  OpKernelContext::Params* ParamsButClearingEigenGPUDevice(
      OpKernelContext::Params* p) {
    // Ensure OpKernelContext constructor will make a new eigen GPU device if
    // necessary.
    p->eigen_gpu_device = nullptr;  // Force allocation
    return p;
  }
};

void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
  const GraphView& gview = impl_->gview_;
  TaggedNodeSeq ready;
  TaggedNodeReadyQueue inline_ready;

  // Parameters passed to OpKernel::Compute.
  TensorValueVec inputs;
  DeviceContextVec input_device_contexts;
  AllocatorAttributeVec input_alloc_attrs;

  OpKernelContext::Params params;
  params.step_id = step_id_;
  Device* device = impl_->params_.device;
  params.device = device;
  params.log_memory = log_memory_;
  params.record_tensor_accesses = impl_->device_record_tensor_accesses_;
  params.rendezvous = rendezvous_;
  params.session_state = session_state_;
  params.tensor_store = tensor_store_;
  params.cancellation_manager = cancellation_manager_;
  params.call_frame = call_frame_;
  params.function_library = impl_->params_.function_library;
  params.resource_manager = device->resource_manager();
  params.step_container = step_container_;
  params.slice_reader_cache = slice_reader_cache_;
  params.inputs = &inputs;
  params.input_device_contexts = &input_device_contexts;
  params.input_alloc_attrs = &input_alloc_attrs;
  params.runner = &runner_;
  params.stats_collector = stats_collector_;

  Status s;
  NodeExecStatsWrapper* stats = nullptr;
  EntryVector outputs;
  bool completed = false;
  inline_ready.push_back(tagged_node);
  while (!inline_ready.empty()) {
    tagged_node = inline_ready.front();
    inline_ready.pop_front();
    const Node* node = tagged_node.node;
    FrameState* input_frame = tagged_node.input_frame;
    const int64 input_iter = tagged_node.input_iter;
    const int id = node->id();
    const NodeItem& item = *gview.node(id);

    // TODO(misard) Replace with a finer-grain enabling flag once we
    // add better optional debugging support.
    if (vlog_ && VLOG_IS_ON(1)) {
      mutex_lock l(input_frame->mu);
      input_frame->GetIteration(input_iter)->mark_started(item.pending_id);
    }

    // Set the device_context for this node id, if it exists.
    if (id < device_context_map_.size()) {
      params.op_device_context = device_context_map_[id];
    }

    params.track_allocations = false;
    stats = nullptr;
    if (stats_collector_ && !tagged_node.is_dead) {
      // track allocations if and only if we are collecting statistics
      params.track_allocations = true;
      stats = new NodeExecStatsWrapper;
      stats->stats()->set_node_name(node->name());
      nodestats::SetScheduled(stats, scheduled_usec);
      nodestats::SetAllStart(stats);
    }

    if (vlog_) {
      VLOG(1) << "Process node: " << id << " step " << params.step_id << " "
              << SummarizeNode(*node) << " is dead: " << tagged_node.is_dead;
    }

    Entry* input_tensors = GetInputTensors(input_frame, input_iter);
    Entry* first_input = input_tensors + item.input_start;
    outputs.clear();

    TensorReferenceVector accessed_tensors;
    DeviceContext* device_context = nullptr;
    // Only execute this node if it is not dead or it is a send/recv
    // transfer node. For transfer nodes, we need to propagate the "dead"
    // bit even when the node is dead.
    bool launched_asynchronously = false;
    if (tagged_node.is_dead && !IsTransferNode(node)) {
      outputs.resize(item.num_outputs);
    } else {
      // Prepares inputs.
      bool is_input_dead = false;
      s = PrepareInputs(item, first_input, &inputs, &input_device_contexts,
                        &input_alloc_attrs, &is_input_dead);
      if (!s.ok()) {
        // Clear inputs.
1583 int num_inputs = item.num_inputs; 1584 for (int i = 0; i < num_inputs; ++i) { 1585 (first_input + i)->ClearVal(); 1586 } 1587 MaybeMarkCompleted(input_frame, input_iter, id); 1588 // Continue to process the nodes in 'inline_ready'. 1589 completed = NodeDone(s, item.node, ready, stats, &inline_ready); 1590 continue; 1591 } 1592 1593 // Set up compute params. 1594 OpKernel* op_kernel = item.kernel; 1595 params.op_kernel = op_kernel; 1596 params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter); 1597 params.is_input_dead = is_input_dead; 1598 params.output_attr_array = item.output_attrs(); 1599 1600 if (item.kernel_is_async) { 1601 // Asynchronous computes. 1602 AsyncOpKernel* async = item.kernel->AsAsync(); 1603 DCHECK(async != nullptr); 1604 launched_asynchronously = true; 1605 AsyncState* state = 1606 new AsyncState(params, tagged_node, &item, first_input, stats); 1607 1608 auto done = [this, state]() { 1609 Device* device = impl_->params_.device; 1610 NodeExecStatsWrapper* stats = state->stats; // Shorthand 1611 Entry* first_input = state->first_input; // Shorthand 1612 1613 nodestats::SetOpEnd(stats); 1614 EntryVector outputs; 1615 Status s = ProcessOutputs(*state->item, &state->ctx, &outputs, stats); 1616 nodestats::SetMemory(stats, &state->ctx); 1617 if (vlog_) { 1618 VLOG(2) << "Async kernel done: " << state->item->node->id() 1619 << " step " << step_id_ << " " 1620 << SummarizeNode(*state->item->node) 1621 << " is dead: " << state->tagged_node.is_dead; 1622 } 1623 1624 // Clears inputs. 1625 const int num_inputs = state->item->num_inputs; 1626 for (int i = 0; i < num_inputs; ++i) { 1627 (first_input + i)->ClearVal(); 1628 } 1629 FrameState* input_frame = state->tagged_node.input_frame; 1630 const int64 input_iter = state->tagged_node.input_iter; 1631 const int id = state->tagged_node.node->id(); 1632 MaybeMarkCompleted(input_frame, input_iter, id); 1633 TaggedNodeSeq ready; 1634 if (s.ok()) { 1635 PropagateOutputs(state->tagged_node, state->item, &outputs, &ready); 1636 } 1637 outputs.clear(); 1638 if (s.ok() && impl_->device_record_tensor_accesses_) { 1639 // Get the list of all tensors accessed during the execution 1640 TensorReferenceVector accessed; 1641 state->ctx.retrieve_accessed_tensors(&accessed); 1642 nodestats::SetReferencedTensors(stats, accessed); 1643 // callee takes ownership of the vector 1644 device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(), 1645 accessed); 1646 } 1647 const bool completed = 1648 NodeDone(s, state->item->node, ready, stats, nullptr); 1649 delete state; 1650 if (completed) Finish(); 1651 }; 1652 nodestats::SetOpStart(stats); 1653 device->ComputeAsync(async, &state->ctx, done); 1654 } else { 1655 // Synchronous computes. 1656 OpKernelContext ctx(¶ms, item.num_outputs); 1657 nodestats::SetOpStart(stats); 1658 device->Compute(CHECK_NOTNULL(op_kernel), &ctx); 1659 nodestats::SetOpEnd(stats); 1660 s = ProcessOutputs(item, &ctx, &outputs, stats); 1661 if (s.ok() && impl_->device_record_tensor_accesses_) { 1662 // Get the list of all tensors accessed during the execution 1663 ctx.retrieve_accessed_tensors(&accessed_tensors); 1664 device_context = ctx.op_device_context(); 1665 } 1666 nodestats::SetMemory(stats, &ctx); 1667 } 1668 } 1669 1670 if (!launched_asynchronously) { 1671 if (vlog_) { 1672 VLOG(2) << "Synchronous kernel done: " << id << " step " 1673 << params.step_id << " " << SummarizeNode(*node) 1674 << " is dead: " << tagged_node.is_dead; 1675 } 1676 1677 // Clears inputs. 
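      // The kernel completed synchronously, so its input entries can be
      // released right away; the asynchronous path performs the same cleanup
      // inside the done callback above.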
1678 const int num_inputs = item.num_inputs; 1679 for (int i = 0; i < num_inputs; ++i) { 1680 (first_input + i)->ClearVal(); 1681 } 1682 MaybeMarkCompleted(input_frame, input_iter, id); 1683 // Propagates outputs. 1684 if (s.ok()) { 1685 PropagateOutputs(tagged_node, &item, &outputs, &ready); 1686 } 1687 outputs.clear(); 1688 if (!accessed_tensors.empty()) { 1689 nodestats::SetReferencedTensors(stats, accessed_tensors); 1690 // device_context is set above in synchronous computes 1691 device->ConsumeListOfAccessedTensors(device_context, accessed_tensors); 1692 } 1693 if (stats) { 1694 scheduled_usec = nodestats::NowInUsec(); 1695 } 1696 // Postprocess. 1697 completed = NodeDone(s, item.node, ready, stats, &inline_ready); 1698 } 1699 } // while !inline_ready.empty() 1700 1701 // This thread of computation is done if completed = true. 1702 if (completed) Finish(); 1703 } 1704 1705 Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input, 1706 TensorValueVec* inputs, 1707 DeviceContextVec* input_device_contexts, 1708 AllocatorAttributeVec* input_alloc_attrs, 1709 bool* is_input_dead) { 1710 const Node* node = item.node; 1711 1712 inputs->clear(); 1713 inputs->resize(item.num_inputs); 1714 input_device_contexts->clear(); 1715 input_device_contexts->resize(item.num_inputs); 1716 input_alloc_attrs->clear(); 1717 input_alloc_attrs->resize(item.num_inputs); 1718 1719 *is_input_dead = false; 1720 1721 bool is_merge = item.is_merge; 1722 for (int i = 0; i < item.num_inputs; ++i) { 1723 const bool expect_ref = IsRefType(item.input_type(i)); 1724 Entry* entry = first_input + i; 1725 (*input_device_contexts)[i] = entry->device_context; 1726 (*input_alloc_attrs)[i] = entry->alloc_attr; 1727 1728 // i-th input. 1729 TensorValue* inp = &(*inputs)[i]; 1730 1731 // Only merge and transfer nodes can have no-value inputs. 1732 if (!entry->has_value) { 1733 if (!is_merge) { 1734 DCHECK(IsTransferNode(node)) << node->name() << " - input " << i; 1735 DCHECK(!entry->val_field_is_set) << node->name() << " - input " << i; 1736 entry->has_value = true; 1737 entry->val_field_is_set = true; 1738 entry->val.Init(*kEmptyTensor); 1739 inp->tensor = entry->val.get(); 1740 *is_input_dead = true; 1741 } 1742 continue; 1743 } 1744 if (entry->ref == nullptr) { 1745 if (expect_ref) { 1746 return AttachDef( 1747 errors::InvalidArgument(i, "-th input expects a ref type"), 1748 item.kernel->def()); 1749 } 1750 inp->tensor = entry->val.get(); 1751 } else { 1752 { 1753 mutex_lock ml(*entry->ref_mu); 1754 if (!entry->ref->IsInitialized() && !IsInitializationOp(item.node)) { 1755 return AttachDef(errors::FailedPrecondition( 1756 "Attempting to use uninitialized value ", 1757 item.kernel->requested_input(i)), 1758 item.kernel->def()); 1759 } 1760 } 1761 if (expect_ref) { 1762 inp->mutex_if_ref = entry->ref_mu; 1763 inp->tensor = entry->ref; 1764 } else { 1765 // Automatically deref the tensor ref when the op expects a 1766 // tensor but is given a ref to a tensor. Need to deref it 1767 // under the mutex. 1768 { 1769 mutex_lock l(*(entry->ref_mu)); 1770 DCHECK(!entry->val_field_is_set); 1771 entry->val.Init(*entry->ref); 1772 entry->val_field_is_set = true; 1773 } 1774 entry->ref = nullptr; 1775 entry->ref_mu = nullptr; 1776 1777 inp->tensor = entry->val.get(); 1778 // The dtype of entry->ref could have been changed by another operation 1779 // that ran after the operation that "produced" it executed, so 1780 // re-validate that the type of the dereferenced tensor matches the 1781 // expected input type. 
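        // The check below turns such a mismatch into an InvalidArgument error
        // instead of letting the kernel read memory of the wrong type.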
1782 if (item.input_type(i) != inp->tensor->dtype()) { 1783 return AttachDef( 1784 errors::InvalidArgument( 1785 i, "-th input expects type ", 1786 DataTypeString(item.input_type(i)), 1787 " but automatically dereferenced input tensor has type ", 1788 DataTypeString(inp->tensor->dtype())), 1789 item.kernel->def()); 1790 } 1791 } 1792 } 1793 } 1794 return Status::OK(); 1795 } 1796 1797 Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, 1798 EntryVector* outputs, 1799 NodeExecStatsWrapper* stats) { 1800 const Node* node = item.node; 1801 DCHECK_EQ(0, outputs->size()); 1802 outputs->resize(item.num_outputs); 1803 1804 Status s = ctx->status(); 1805 if (!s.ok()) { 1806 s = AttachDef(s, item.kernel->def()); 1807 // TODO(misard) Replace with a finer-grain enabling flag once we 1808 // add better optional debugging support. 1809 if (vlog_ && VLOG_IS_ON(1)) { 1810 LOG(WARNING) << this << " Compute status: " << s; 1811 DumpState(); 1812 } 1813 if (s.code() == error::RESOURCE_EXHAUSTED) { 1814 if (stats_collector_) { 1815 string err = stats_collector_->ReportAllocsOnResourceExhausted( 1816 s.error_message()); 1817 s = Status(s.code(), strings::StrCat(s.error_message(), err)); 1818 } else { 1819 s = Status( 1820 s.code(), 1821 strings::StrCat( 1822 s.error_message(), 1823 "\nHint: If you want to see a list of allocated tensors when " 1824 "OOM happens, add report_tensor_allocations_upon_oom " 1825 "to RunOptions for current allocation info.\n")); 1826 } 1827 } 1828 return s; 1829 } 1830 1831 // Get the device_context for this node id, if it exists. 1832 DeviceContext* device_context = nullptr; 1833 if (node->id() < device_context_map_.size()) { 1834 device_context = device_context_map_[node->id()]; 1835 } 1836 1837 // Experimental: debugger (tfdb) access to intermediate node completion. 1838 if (item.num_outputs == 0 && impl_->params_.node_outputs_cb != nullptr) { 1839 // If the node has no output, invoke the callback with output slot set to 1840 // -1, signifying that this is a no-output node. 1841 s.Update(impl_->params_.node_outputs_cb(item.node->name(), -1, nullptr, 1842 false, ctx)); 1843 } 1844 1845 for (int i = 0; i < item.num_outputs; ++i) { 1846 const TensorValue val = ctx->release_output(i); 1847 if (*ctx->is_output_dead() || val.tensor == nullptr) { 1848 // Unless it's a Switch or a Recv, the node must produce a 1849 // tensor value at i-th output. 1850 if (!IsSwitch(node) && !IsRecv(node)) { 1851 s.Update(errors::Internal("Missing ", i, "-th output from ", 1852 SummarizeNode(*node))); 1853 } 1854 } else { 1855 Entry* out = &((*outputs)[i]); 1856 1857 // Set the device context of the output entry. 1858 out->device_context = device_context; 1859 1860 // Set the allocator attributes of the output entry. 1861 out->alloc_attr = ctx->output_alloc_attr(i); 1862 1863 // Sanity check of output tensor types. 1864 DataType dtype; 1865 if (val.is_ref()) { 1866 mutex_lock ml(*val.mutex_if_ref); 1867 dtype = MakeRefType(val->dtype()); 1868 } else { 1869 dtype = val->dtype(); 1870 } 1871 if (dtype == item.output_type(i)) { 1872 if (stats && val.tensor->IsInitialized()) { 1873 nodestats::SetOutput(stats, i, val.tensor); 1874 } 1875 if (val.is_ref()) { 1876 out->has_value = true; 1877 out->ref = val.tensor; 1878 out->ref_mu = val.mutex_if_ref; 1879 if (log_memory_) { 1880 Tensor to_log; 1881 { 1882 // Dereference the tensor under the lock. 
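          // Holding ref_mu while copying gives RecordTensorOutput a stable
          // view of the tensor even if another kernel assigns to the
          // underlying variable at the same time.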
1883 mutex_lock l(*out->ref_mu); 1884 to_log = *out->ref; 1885 } 1886 LogMemory::RecordTensorOutput(ctx->op_kernel().name(), 1887 ctx->step_id(), i, to_log); 1888 } 1889 1890 // Experimental: debugger (tfdb) access to intermediate node 1891 // outputs. 1892 if (impl_->params_.node_outputs_cb != nullptr) { 1893 s.Update(impl_->params_.node_outputs_cb(item.node->name(), i, 1894 out->ref, true, ctx)); 1895 } 1896 } else { 1897 // NOTE that std::move is used here, so val.tensor goes to 1898 // uninitialized state (val.tensor->IsInitialized return false). 1899 DCHECK(!out->val_field_is_set); 1900 out->has_value = true; 1901 out->val_field_is_set = true; 1902 out->val.Init(std::move(*val.tensor)); 1903 if (log_memory_) { 1904 LogMemory::RecordTensorOutput(ctx->op_kernel().name(), 1905 ctx->step_id(), i, *out->val); 1906 } 1907 1908 // Experimental: debugger access to intermediate node outputs. 1909 if (impl_->params_.node_outputs_cb != nullptr) { 1910 s.Update(impl_->params_.node_outputs_cb( 1911 item.node->name(), i, out->val.get(), false, ctx)); 1912 } 1913 } 1914 } else { 1915 s.Update(errors::Internal("Output ", i, " of type ", 1916 DataTypeString(dtype), 1917 " does not match declared output type ", 1918 DataTypeString(item.output_type(i)), 1919 " for node ", SummarizeNode(*node))); 1920 } 1921 } 1922 if (!val.is_ref()) { 1923 // If OpKernelContext returns outputs via pass-by-value, we 1924 // don't need this trouble. 1925 delete val.tensor; 1926 } 1927 } 1928 return s; 1929 } 1930 1931 void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node, 1932 const NodeItem* item, EntryVector* outputs, 1933 TaggedNodeSeq* ready) { 1934 const Node* node = tagged_node.node; 1935 FrameState* input_frame = tagged_node.input_frame; 1936 const int64 input_iter = tagged_node.input_iter; 1937 const bool is_dead = tagged_node.is_dead; 1938 1939 // Propagates outputs along out edges, and puts newly ready nodes 1940 // into the ready queue. 1941 ready->clear(); 1942 bool is_frame_done = false; 1943 FrameState* output_frame = input_frame; 1944 int64 output_iter = input_iter; 1945 1946 if (!item->is_enter_exit_or_next_iter) { 1947 // Fast path for nodes types that don't need special handling 1948 DCHECK_EQ(input_frame, output_frame); 1949 // Normal path for most nodes 1950 mutex_lock l(input_frame->mu); 1951 output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready); 1952 is_frame_done = input_frame->DecrementOutstandingOpsLocked( 1953 &impl_->gview_, input_iter, ready); 1954 } else if (item->is_enter) { 1955 bool is_constant; 1956 const Status s = GetNodeAttr(node->attrs(), "is_constant", &is_constant); 1957 DCHECK(s.ok()) << s; 1958 FindOrCreateChildFrame(input_frame, input_iter, node, &output_frame); 1959 output_iter = 0; 1960 { 1961 const NodeItem* item = impl_->gview_.node(node->id()); 1962 mutex_lock l(output_frame->mu); 1963 if (is_constant) { 1964 // Propagate to all active iterations if this is a loop invariant. 1965 output_frame->AddLoopInv(item, (*outputs)[0], ready); 1966 } else { 1967 output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready); 1968 } 1969 output_frame->num_pending_inputs--; 1970 } 1971 is_frame_done = 1972 input_frame->DecrementOutstandingOps(&impl_->gview_, input_iter, ready); 1973 } else if (item->is_exit) { 1974 if (is_dead) { 1975 mutex_lock l(input_frame->mu); 1976 // Stop and remember this node if it is a dead exit. 
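      // Dead exits are not propagated immediately; they are replayed against
      // the parent frame in DeleteFrame() once the whole frame is known to be
      // done.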
1977 if (input_iter == input_frame->iteration_count) { 1978 input_frame->dead_exits.push_back(node); 1979 } 1980 is_frame_done = input_frame->DecrementOutstandingOpsLocked( 1981 &impl_->gview_, input_iter, ready); 1982 } else { 1983 output_frame = input_frame->parent_frame; 1984 output_iter = input_frame->parent_iter; 1985 { 1986 mutex_lock l(output_frame->mu); 1987 output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready); 1988 } 1989 is_frame_done = input_frame->DecrementOutstandingOps(&impl_->gview_, 1990 input_iter, ready); 1991 } 1992 } else { 1993 DCHECK(IsNextIteration(node)); 1994 mutex_lock l(input_frame->mu); 1995 if (is_dead) { 1996 // Stop the deadness propagation. 1997 output_frame = nullptr; 1998 } else { 1999 if (input_iter == input_frame->iteration_count && 2000 input_frame->num_outstanding_iterations == 2001 input_frame->max_parallel_iterations) { 2002 // Reached the maximum for parallel iterations. 2003 input_frame->next_iter_roots.push_back({node, (*outputs)[0]}); 2004 output_frame = nullptr; 2005 } else { 2006 // If this is a new iteration, start it. 2007 if (input_iter == input_frame->iteration_count) { 2008 input_frame->IncrementIteration(&impl_->gview_, ready); 2009 } 2010 output_iter = input_iter + 1; 2011 } 2012 } 2013 if (output_frame != nullptr) { 2014 // This is the case when node is not Enter, Exit, or NextIteration. 2015 DCHECK(input_frame == output_frame); 2016 output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready); 2017 } 2018 is_frame_done = input_frame->DecrementOutstandingOpsLocked( 2019 &impl_->gview_, input_iter, ready); 2020 } 2021 2022 // At this point, this node is completely done. We also know if the 2023 // completion of this node makes its frame completed. 2024 if (is_frame_done) { 2025 FrameState* parent_frame = input_frame->parent_frame; 2026 const int64 parent_iter = input_frame->parent_iter; 2027 DeleteFrame(input_frame, ready); 2028 if (parent_frame != nullptr) { 2029 // The completion of frame may cause completions in its parent frame. 2030 // So clean things up recursively. 2031 CleanupFramesIterations(parent_frame, parent_iter, ready); 2032 } 2033 } 2034 } 2035 2036 bool ExecutorState::NodeDone(const Status& s, const Node* node, 2037 const TaggedNodeSeq& ready, 2038 NodeExecStatsWrapper* stats, 2039 TaggedNodeReadyQueue* inline_ready) { 2040 nodestats::SetAllEnd(stats); 2041 if (stats_collector_ != nullptr && !SetTimelineLabel(node, stats)) { 2042 // Only record non-transfer nodes. 2043 // Transfers 'stats' ownership to 'stats_collector_'. 2044 stats_collector_->Save(impl_->params_.device->name(), stats); 2045 } else if (stats) { 2046 delete stats; 2047 } 2048 2049 bool abort_run = false; 2050 if (!s.ok()) { 2051 // Some error happened. This thread of computation is done. 2052 mutex_lock l(mu_); 2053 if (status_.ok()) { 2054 abort_run = true; 2055 status_ = s; 2056 } 2057 } 2058 if (abort_run) { 2059 TRACEPRINTF("StartAbort: %s", s.ToString().c_str()); 2060 if (rendezvous_) { 2061 rendezvous_->StartAbort(s); 2062 } 2063 if (cancellation_manager_) { 2064 cancellation_manager_->StartCancel(); 2065 } 2066 } 2067 2068 bool completed = false; 2069 const size_t ready_size = ready.size(); 2070 if (ready_size == 0 || !s.ok()) { 2071 completed = (num_outstanding_ops_.fetch_sub(1) == 1); 2072 } else if (ready_size > 1) { 2073 num_outstanding_ops_.fetch_add(ready_size - 1, std::memory_order_relaxed); 2074 } 2075 2076 // Schedule the ready nodes in 'ready'. 
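  // On error nothing further is scheduled; num_outstanding_ops_ was adjusted
  // above as if 'ready' were empty, so the last node to finish still drives
  // the count to zero and the caller invokes Finish().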
2077 if (s.ok()) { 2078 ScheduleReady(ready, inline_ready); 2079 } 2080 return completed; 2081 } 2082 2083 void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready, 2084 TaggedNodeReadyQueue* inline_ready) { 2085 if (ready.empty()) return; 2086 2087 int64 scheduled_usec = 0; 2088 if (stats_collector_) { 2089 scheduled_usec = nodestats::NowInUsec(); 2090 } 2091 if (inline_ready == nullptr) { 2092 // Schedule to run all the ready ops in thread pool. 2093 for (auto& tagged_node : ready) { 2094 runner_([=]() { Process(tagged_node, scheduled_usec); }); 2095 } 2096 return; 2097 } 2098 const GraphView& gview = impl_->gview_; 2099 const TaggedNode* curr_expensive_node = nullptr; 2100 for (auto& tagged_node : ready) { 2101 const NodeItem& item = *gview.node(tagged_node.node->id()); 2102 if (tagged_node.is_dead || !item.kernel_is_expensive) { 2103 // Inline this inexpensive node. 2104 inline_ready->push_back(tagged_node); 2105 } else { 2106 if (curr_expensive_node) { 2107 // Dispatch to another thread since there is plenty of work to 2108 // do for this thread. 2109 runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node, 2110 scheduled_usec)); 2111 } 2112 curr_expensive_node = &tagged_node; 2113 } 2114 } 2115 if (curr_expensive_node) { 2116 if (inline_ready->empty()) { 2117 // Tail recursion optimization 2118 inline_ready->push_back(*curr_expensive_node); 2119 } else { 2120 // There are inline nodes to run already. We dispatch this expensive 2121 // node to other thread. 2122 runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node, 2123 scheduled_usec)); 2124 } 2125 } 2126 } 2127 2128 inline void ExecutorState::MaybeMarkCompleted(FrameState* frame, int64 iter, 2129 int64 node_id) { 2130 // TODO(misard) Replace with a finer-grain enabling flag once we 2131 // add better optional debugging support. 
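  // The completion marks are consumed only by the state-dumping code below
  // (a completed node no longer shows up as pending or active), so the work
  // is skipped unless verbose logging is enabled.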
2132 if (vlog_ && VLOG_IS_ON(1)) { 2133 const NodeItem* item = impl_->gview_.node(node_id); 2134 mutex_lock l(frame->mu); 2135 frame->GetIteration(iter)->mark_completed(item->pending_id); 2136 } 2137 } 2138 2139 const Tensor* ExecutorState::GetTensorValueForDump(const Entry& input) { 2140 if (!input.has_value) { 2141 return kEmptyTensor; 2142 } else if (input.ref == nullptr) { 2143 return input.val.get(); 2144 } else { 2145 return input.ref; 2146 } 2147 } 2148 2149 void ExecutorState::DumpPendingNodeState( 2150 const int node_id, const Entry* input_vector, 2151 const bool show_nodes_with_no_ready_inputs) { 2152 const NodeItem& node_item = *impl_->gview_.node(node_id); 2153 const Node& node = *node_item.node; 2154 const int input_base = node_item.input_start; 2155 if (!show_nodes_with_no_ready_inputs) { 2156 bool has_ready_input = false; 2157 for (int i = 0; i < node.num_inputs(); ++i) { 2158 const Entry& input = input_vector[input_base + i]; 2159 const Tensor* tensor = GetTensorValueForDump(input); 2160 if (tensor->IsInitialized()) { 2161 has_ready_input = true; 2162 break; 2163 } 2164 } 2165 if (!has_ready_input) { 2166 return; 2167 } 2168 } 2169 LOG(WARNING) << " Pending Node: " << node.DebugString(); 2170 for (int i = 0; i < node.num_inputs(); ++i) { 2171 const Entry& input = input_vector[input_base + i]; 2172 const Tensor* tensor = GetTensorValueForDump(input); 2173 if (tensor->IsInitialized()) { 2174 LOG(WARNING) << " Input " << i << ": " 2175 << strings::StrCat( 2176 "Tensor<type: ", DataTypeString(tensor->dtype()), 2177 " shape: ", tensor->shape().DebugString(), ">"); 2178 } else { 2179 LOG(WARNING) << " Input " << i << ": not present"; 2180 } 2181 } 2182 } 2183 2184 void ExecutorState::DumpActiveNodeState(const int node_id, 2185 const Entry* input_vector) { 2186 const NodeItem& node_item = *impl_->gview_.node(node_id); 2187 const Node& node = *node_item.node; 2188 LOG(WARNING) << " Active Node: " << node.DebugString(); 2189 const int input_base = node_item.input_start; 2190 for (int i = 0; i < node.num_inputs(); ++i) { 2191 const Entry& input = input_vector[input_base + i]; 2192 const Tensor* tensor = GetTensorValueForDump(input); 2193 if (tensor->IsInitialized()) { 2194 LOG(WARNING) << " Input " << i << ": " 2195 << strings::StrCat( 2196 "Tensor<type: ", DataTypeString(tensor->dtype()), 2197 " shape: ", tensor->shape().DebugString(), ">"); 2198 } else { 2199 LOG(WARNING) << " Input " << i << ": not present"; 2200 } 2201 } 2202 } 2203 2204 void ExecutorState::DumpIterationState(const FrameState* frame, 2205 IterationState* iteration) { 2206 const std::vector<const Node*>* nodes = frame->nodes; 2207 // Dump any waiting nodes that are holding on to tensors. 2208 for (const Node* node : *nodes) { 2209 const int node_id = node->id(); 2210 PendingCounts::Handle pending_id = impl_->gview_.node(node_id)->pending_id; 2211 if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY || 2212 iteration->node_state(pending_id) == PendingCounts::PENDING_READY) { 2213 DumpPendingNodeState(node_id, iteration->input_tensors, false); 2214 } 2215 } 2216 // Then the active nodes. 2217 for (const Node* node : *nodes) { 2218 const int node_id = node->id(); 2219 PendingCounts::Handle pending_id = impl_->gview_.node(node_id)->pending_id; 2220 if (iteration->node_state(pending_id) == PendingCounts::STARTED) { 2221 DumpActiveNodeState(node_id, iteration->input_tensors); 2222 } 2223 } 2224 // Show all input tensors in use. 
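  // Walk every input slot of the iteration and add up TotalBytes() to get a
  // rough estimate of the memory the executor is still holding for it.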
  const int total_input_tensors = frame->total_input_tensors;
  size_t total_bytes = 0;
  for (int i = 0; i < total_input_tensors; ++i) {
    const Entry& input = iteration->input_tensors[i];
    const Tensor* tensor = GetTensorValueForDump(input);
    if (tensor->IsInitialized()) {
      LOG(WARNING) << " Input " << i << ": "
                   << strings::StrCat(
                          "Tensor<type: ", DataTypeString(tensor->dtype()),
                          " shape: ", tensor->shape().DebugString(),
                          ", bytes: ", tensor->TotalBytes(), ">");
      total_bytes += tensor->TotalBytes();
    }
  }
  LOG(WARNING) << " Total bytes " << total_bytes;
}

void ExecutorState::DumpState() {
  mutex_lock l(mu_);
  if (!dumped_on_error_) {
    LOG(WARNING) << "Dumping state";
    for (auto& frame : outstanding_frames_) {
      LOG(WARNING) << frame.first;
      FrameState* frame_state = frame.second;
      mutex_lock frame_lock(frame_state->mu);
      for (IterationState* iteration : frame_state->iterations) {
        LOG(WARNING) << " Iteration:";
        DumpIterationState(frame_state, iteration);
      }
    }
    dumped_on_error_ = true;
  }
}

void ExecutorState::Finish() {
  mu_.lock();
  auto status = status_;
  auto done_cb = std::move(done_cb_);
  auto runner = std::move(runner_);
  mu_.unlock();
  if (sync_on_finish_ && status.ok()) {
    // Block until the device has finished all queued operations. For
    // devices like GPUs that continue to execute Ops after their Compute
    // methods have completed, this ensures that control is not returned to
    // the user until the step (and its side-effects) has actually completed.
    status = impl_->params_.device->Sync();
  }
  delete this;
  CHECK(done_cb != nullptr);
  runner([=]() { done_cb(status); });
}

void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
                                           const Node* node,
                                           FrameState** child) {
  // Get the child frame name.
  string enter_name;
  Status s = GetNodeAttr(node->attrs(), "frame_name", &enter_name);
  DCHECK(s.ok()) << s;
  const string child_name = MakeFrameName(frame, iter, enter_name);

  {
    mutex_lock executor_lock(mu_);
    auto it = outstanding_frames_.find(child_name);
    if (it != outstanding_frames_.end()) {
      *child = it->second;
      return;
    }
  }

  // Need to create a new frame instance.
  // Note that this new frame instance is created without any locks.
  if (vlog_) VLOG(2) << "Create frame: " << child_name;

  int parallel_iters;
  s = GetNodeAttr(node->attrs(), "parallel_iterations", &parallel_iters);
  DCHECK(s.ok()) << s;
  FrameState* temp = new FrameState(impl_, parallel_iters);
  temp->frame_name = child_name;
  temp->frame_id = Hash64(child_name);
  temp->parent_frame = frame;
  temp->parent_iter = iter;
  temp->InitializeFrameInfo(enter_name);

  // 'iterations' is a fixed-length circular buffer.
  temp->iterations.resize(temp->max_parallel_iterations + 1);
  // Initialize iteration 0.
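  // The lookup below is repeated under mu_: another thread may have
  // registered the same child frame while 'temp' was being built without the
  // lock, in which case the freshly built frame is discarded.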
  temp->iterations[0] =
      new IterationState(temp->pending_counts, temp->total_input_tensors);

  {
    mutex_lock executor_lock(mu_);
    auto it = outstanding_frames_.find(child_name);
    if (it != outstanding_frames_.end()) {
      *child = it->second;
    } else {
      mutex_lock frame_lock(frame->mu);
      frame->GetIteration(iter)->outstanding_frame_count++;
      outstanding_frames_[child_name] = temp;
      *child = temp;
      temp = nullptr;
    }
  }
  delete temp;  // Not used so delete it.
}

void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
  // First, propagate dead_exits (if any) to the parent frame.
  FrameState* parent_frame = frame->parent_frame;
  const int64 parent_iter = frame->parent_iter;
  if (parent_frame != nullptr) {
    mutex_lock parent_frame_lock(parent_frame->mu);
    // Propagate all the dead exits to the parent frame.
    for (const Node* node : frame->dead_exits) {
      auto parent_iter_state = parent_frame->GetIteration(parent_iter);
      for (const Edge* e : node->out_edges()) {
        const Node* dst_node = e->dst();

        const auto dst_pending_id =
            impl_->gview_.node(dst_node->id())->pending_id;

        // TODO(yuanbyu): We don't need this if we require the subgraph
        // given to an executor not to contain a sink node.
        if (dst_node->IsSink()) continue;

        bool dst_dead = true;
        bool dst_ready = false;
        // We know this is a dead input to dst.
        if (IsMerge(dst_node)) {
          if (e->IsControlEdge()) {
            parent_iter_state->decrement_pending(dst_pending_id, 2);
            int count = parent_iter_state->pending(dst_pending_id);
            int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
            dst_dead = (dead_cnt == dst_node->num_inputs());
            dst_ready = (count == 0) || ((count == 1) && dst_dead);
          } else {
            parent_iter_state->increment_dead_count(dst_pending_id);
            const int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
            dst_dead = (dead_cnt == dst_node->num_inputs());
            dst_ready =
                (parent_iter_state->pending(dst_pending_id) == 1) && dst_dead;
          }
        } else {
          parent_iter_state->increment_dead_count(dst_pending_id);
          dst_ready =
              (parent_iter_state->decrement_pending(dst_pending_id, 1) == 0);
        }
        if (dst_ready) {
          if (IsControlTrigger(dst_node)) dst_dead = false;
          ready->push_back(
              TaggedNode(dst_node, parent_frame, parent_iter, dst_dead));
          parent_iter_state->outstanding_ops++;
        }
      }
    }
  }

  // Delete the frame.
  const string& frame_name = frame->frame_name;
  if (vlog_) VLOG(2) << "Delete frame " << frame_name;
  {
    mutex_lock executor_lock(mu_);
    outstanding_frames_.erase(frame_name);
  }
  delete frame;
}

void ExecutorState::CleanupFramesIterations(FrameState* frame, int64 iter,
                                            TaggedNodeSeq* ready) {
  bool is_frame_done = false;
  {
    mutex_lock frame_lock(frame->mu);
    frame->GetIteration(iter)->outstanding_frame_count--;
    is_frame_done = frame->CleanupIterations(&impl_->gview_, iter, ready);
  }
  if (is_frame_done) {
    FrameState* parent_frame = frame->parent_frame;
    const int64 parent_iter = frame->parent_iter;
    DeleteFrame(frame, ready);
    if (parent_frame != nullptr) {
      // The completion of frame may cause completions in its parent frame.
      // So clean things up recursively.
2407 CleanupFramesIterations(parent_frame, parent_iter, ready); 2408 } 2409 } 2410 } 2411 2412 void ExecutorState::FrameState::ActivateNodes(const NodeItem* item, 2413 const bool is_dead, int64 iter, 2414 EntryVector* outputs, 2415 TaggedNodeSeq* ready) { 2416 const GraphView& gview = executor->gview_; 2417 IterationState* iter_state = GetIteration(iter); 2418 const size_t num_output_edges = item->num_output_edges; 2419 const EdgeInfo* edges = item->output_edge_list(); 2420 Entry* input_tensors = iter_state->input_tensors; 2421 for (size_t out_index = 0; out_index < num_output_edges; out_index++) { 2422 const EdgeInfo& e = edges[out_index]; 2423 const int dst_id = e.dst_id; 2424 const NodeItem* dst_item = gview.node(dst_id); 2425 const PendingCounts::Handle dst_pending_id = dst_item->pending_id; 2426 const int src_slot = e.output_slot; 2427 2428 // TODO(yuanbyu): We don't need this if we require the subgraph 2429 // given to an executor not to contain a sink node. 2430 if (dst_item->is_sink) continue; 2431 2432 bool dst_dead = false; 2433 bool dst_ready = false; 2434 // True iff this input for dst is needed. We only set this input for 2435 // dst if this flag is true. This is needed to make the thread safety 2436 // analysis happy. 2437 const bool is_control_edge = (src_slot == Graph::kControlSlot); 2438 bool dst_need_input = !is_control_edge; 2439 if (dst_item->is_merge) { 2440 // A merge node is ready if all control inputs have arrived and either 2441 // a) a live data input becomes available or b) all data inputs are 2442 // dead. For Merge, pending's LSB is set iff a live data input has 2443 // arrived. 2444 if (is_control_edge) { 2445 iter_state->decrement_pending(dst_pending_id, 2); 2446 int count = iter_state->pending(dst_pending_id); 2447 int dead_cnt = iter_state->dead_count(dst_pending_id); 2448 dst_dead = (dead_cnt == dst_item->num_inputs); 2449 dst_ready = (count == 0) || ((count == 1) && dst_dead); 2450 } else { 2451 if ((*outputs)[src_slot].has_value) { 2452 // This is a live data input. 2453 int count = iter_state->pending(dst_pending_id); 2454 iter_state->mark_live(dst_pending_id); 2455 // Only the first live edge sets the input and (potentially) 2456 // triggers execution. The low bit of count is set if and 2457 // only if no live input has been used yet (mark_live clears 2458 // it). The node should be started if and only if this is 2459 // the first live input and there are no pending control 2460 // edges, i.e. count == 1. 2461 dst_ready = (count == 1); 2462 dst_need_input = ((count & 0x1) == 1); 2463 } else { 2464 // This is a dead data input. Note that dst_node is dead if node is 2465 // a dead enter. We need this to handle properly a while loop on 2466 // the untaken branch of a conditional. 2467 // TODO(yuanbyu): This is a bit hacky, but a good solution for 2468 // now. 
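          // Record the dead input. Per the checks below, the Merge is
          // scheduled (as dead) only when its pending count has dropped to 1
          // and either all of its data inputs are dead or the value comes
          // from a dead Enter.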
2469 iter_state->increment_dead_count(dst_pending_id); 2470 const int dead_cnt = iter_state->dead_count(dst_pending_id); 2471 dst_dead = (dead_cnt == dst_item->num_inputs) || item->is_enter; 2472 dst_ready = (iter_state->pending(dst_pending_id) == 1) && dst_dead; 2473 dst_need_input = false; 2474 } 2475 } 2476 } else { 2477 const bool increment_dead = 2478 (is_dead || (!is_control_edge && !(*outputs)[src_slot].has_value)); 2479 int pending, dead; 2480 iter_state->adjust_for_activation(dst_pending_id, increment_dead, 2481 &pending, &dead); 2482 dst_dead = (dead > 0); 2483 dst_ready = (pending == 0); 2484 } 2485 2486 if (dst_need_input) { 2487 const int dst_slot = e.input_slot; 2488 const int dst_loc = dst_item->input_start + dst_slot; 2489 if (e.is_last) { 2490 input_tensors[dst_loc] = std::move((*outputs)[src_slot]); 2491 } else { 2492 input_tensors[dst_loc] = (*outputs)[src_slot]; 2493 } 2494 } 2495 2496 // Add dst to the ready queue if it's ready 2497 if (dst_ready) { 2498 if (dst_item->is_control_trigger) dst_dead = false; 2499 ready->push_back(TaggedNode(dst_item->node, this, iter, dst_dead)); 2500 iter_state->outstanding_ops++; 2501 } 2502 } 2503 } 2504 2505 void ExecutorState::FrameState::ActivateNexts(const GraphView* gview, 2506 int64 iter, 2507 TaggedNodeSeq* ready) { 2508 // Propagate the deferred NextIteration nodes to the new iteration. 2509 for (auto& node_entry : next_iter_roots) { 2510 const Node* node = node_entry.first; 2511 const Entry& entry = node_entry.second; 2512 const bool is_dead = !entry.has_value; 2513 const NodeItem* item = gview->node(node->id()); 2514 EntryVector outputs{entry}; 2515 ActivateNodes(item, is_dead, iter, &outputs, ready); 2516 } 2517 next_iter_roots.clear(); 2518 } 2519 2520 void ExecutorState::FrameState::ActivateLoopInvs(const GraphView* gview, 2521 int64 iter, 2522 TaggedNodeSeq* ready) { 2523 // Propagate loop invariants to the new iteration. 2524 for (auto& node_entry : inv_values) { 2525 const Node* node = node_entry.first; 2526 const Entry& entry = node_entry.second; 2527 const bool is_dead = !entry.has_value; 2528 const NodeItem* item = gview->node(node->id()); 2529 EntryVector outputs{entry}; 2530 ActivateNodes(item, is_dead, iter, &outputs, ready); 2531 } 2532 } 2533 2534 void ExecutorState::FrameState::AddLoopInv(const NodeItem* item, 2535 const Entry& entry, 2536 TaggedNodeSeq* ready) { 2537 // Store this value. 2538 inv_values.push_back({item->node, entry}); 2539 2540 // Make this value available to all iterations. 2541 const bool is_dead = !entry.has_value; 2542 for (int i = 0; i <= iteration_count; ++i) { 2543 EntryVector outputs{entry}; 2544 ActivateNodes(item, is_dead, i, &outputs, ready); 2545 } 2546 } 2547 2548 bool ExecutorState::FrameState::IsIterationDone(int64 iter) { 2549 IterationState* iter_state = GetIteration(iter); 2550 if (iter_state->outstanding_ops == 0 && 2551 iter_state->outstanding_frame_count == 0) { 2552 if (iter == 0) { 2553 // The enclosing frame has no pending input. 2554 return num_pending_inputs == 0; 2555 } else { 2556 // The preceding iteration is deleted (and therefore done). 2557 return (GetIteration(iter - 1) == nullptr); 2558 } 2559 } 2560 return false; 2561 } 2562 2563 void ExecutorState::FrameState::IncrementIteration(const GraphView* gview, 2564 TaggedNodeSeq* ready) { 2565 iteration_count++; 2566 const int64 next_iter = iteration_count; 2567 2568 // Initialize the next iteration. 
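  // 'iterations' is a fixed-length circular buffer of max_parallel_iterations
  // + 1 slots (see FindOrCreateChildFrame), so the slot reused by
  // SetIteration below has already been freed by CleanupIterations.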
2569 IterationState* iter_state = 2570 new IterationState(pending_counts, total_input_tensors); 2571 SetIteration(next_iter, iter_state); 2572 num_outstanding_iterations++; 2573 dead_exits.clear(); 2574 2575 // Activate the successors of the deferred roots in the new iteration. 2576 ActivateNexts(gview, next_iter, ready); 2577 2578 // Activate the loop invariants in the new iteration. 2579 ActivateLoopInvs(gview, next_iter, ready); 2580 } 2581 2582 bool ExecutorState::FrameState::CleanupIterations(const GraphView* gview, 2583 int64 iter, 2584 TaggedNodeSeq* ready) { 2585 int64 curr_iter = iter; 2586 while (curr_iter <= iteration_count && IsIterationDone(curr_iter)) { 2587 // Delete the iteration curr_iter. 2588 delete GetIteration(curr_iter); 2589 SetIteration(curr_iter, nullptr); 2590 --num_outstanding_iterations; 2591 ++curr_iter; 2592 2593 // When one iteration is completed, we check for deferred iteration, 2594 // and start it if there is one. 2595 if (!next_iter_roots.empty()) { 2596 IncrementIteration(gview, ready); 2597 } 2598 } 2599 return IsFrameDone(); 2600 } 2601 2602 void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) { 2603 (new ExecutorState(args, this))->RunAsync(std::move(done)); 2604 } 2605 2606 } // end namespace 2607 2608 Status NewLocalExecutor(const LocalExecutorParams& params, 2609 std::unique_ptr<const Graph> graph, 2610 Executor** executor) { 2611 ExecutorImpl* impl = new ExecutorImpl(params, std::move(graph)); 2612 const Status s = impl->Initialize(); 2613 if (s.ok()) { 2614 *executor = impl; 2615 } else { 2616 delete impl; 2617 } 2618 return s; 2619 } 2620 2621 Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib, 2622 const NodeDef& ndef, int graph_def_version, 2623 OpKernel** kernel) { 2624 const auto device_type = DeviceType(device->attributes().device_type()); 2625 auto allocator = device->GetAllocator(AllocatorAttributes()); 2626 return CreateOpKernel(device_type, device, allocator, flib, ndef, 2627 graph_def_version, kernel); 2628 } 2629 2630 void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; } 2631 2632 } // end namespace tensorflow 2633
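// ---------------------------------------------------------------------------
// Rough usage sketch for the factory above, for orientation only. It is not
// part of the build; the graph and device here are placeholders, and real
// callers (e.g. the direct session) wire this up differently.
//
//   std::unique_ptr<const Graph> graph = ...;  // built elsewhere (placeholder)
//   Device* device = ...;                      // e.g. a CPU device (placeholder)
//   const int graph_def_version = graph->versions().producer();
//
//   LocalExecutorParams params;
//   params.device = device;
//   params.function_library = nullptr;
//   params.create_kernel = [device, graph_def_version](const NodeDef& ndef,
//                                                      OpKernel** kernel) {
//     return CreateNonCachedKernel(device, nullptr, ndef, graph_def_version,
//                                  kernel);
//   };
//   params.delete_kernel = [](OpKernel* kernel) {
//     DeleteNonCachedKernel(kernel);
//   };
//
//   Executor* executor = nullptr;
//   TF_CHECK_OK(NewLocalExecutor(params, std::move(graph), &executor));
//
//   Executor::Args args;
//   args.step_id = 1;
//   args.runner = [](Executor::Args::Closure c) { c(); };  // run inline
//   executor->RunAsync(args, [](const Status& s) { TF_CHECK_OK(s); });
//   // Wait for the done callback before deleting 'executor'.
// ---------------------------------------------------------------------------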