/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/costs/virtual_scheduler.h"

#include <math.h>

#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace grappler {
namespace {

Costs CombineCosts(const Costs& left, const Costs& right) {
  CHECK_NE(left.max_memory, kMemoryUnknown);
  CHECK_NE(left.max_per_op_buffers, kMemoryUnknown);
  CHECK_NE(left.max_per_op_streaming, kMemoryUnknown);

  Costs result = left;
  result.execution_time += right.execution_time;
  if (right.inaccurate) {
    result.inaccurate = true;
  }
  if (right.max_memory != kMemoryUnknown) {
    result.max_memory += right.max_memory;
  }
  if (right.max_per_op_buffers != kMemoryUnknown) {
    result.max_per_op_buffers =
        std::max(left.max_per_op_buffers, right.max_per_op_buffers);
  }
  if (right.max_per_op_streaming != kMemoryUnknown) {
    result.max_per_op_streaming =
        std::max(left.max_per_op_streaming, right.max_per_op_streaming);
  }
  VLOG(4) << "costs execution_time=" << result.execution_time.count()
          << " max_memory=" << result.max_memory
          << " max_per_op_buffers=" << result.max_per_op_buffers
          << " max_per_op_streaming=" << result.max_per_op_streaming;
  return result;
}
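
// Illustrative sketch (not part of the build): CombineCosts() accumulates
// execution_time and max_memory, but takes the maximum of the per-op buffer
// fields, and ignores unknown fields on the right-hand side. For example,
// assuming Costs::Duration is in nanoseconds:
//
//   Costs left = Costs::ZeroCosts();
//   left.execution_time = Costs::Duration(3000);   // 3us.
//   left.max_per_op_buffers = 64;
//   Costs right = Costs::ZeroCosts();
//   right.execution_time = Costs::Duration(2000);  // 2us.
//   right.max_per_op_buffers = 128;
//   Costs combined = CombineCosts(left, right);
//   // combined.execution_time == Costs::Duration(5000)
//   // combined.max_per_op_buffers == 128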

// Key to the cached _Recv ops map, and its hash and predicate structures.
struct RecvNodeDescriptor {
  const NodeDef* node;
  const int port_num;
  const string device;

  RecvNodeDescriptor(const NodeDef* node_, const int port_num_,
                     const string& device_)
      : node(node_), port_num(port_num_), device(device_) {}
};

struct RecvNodeDescriptorHash {
  std::size_t operator()(const RecvNodeDescriptor& recv_node) const {
    return std::hash<const NodeDef*>()(recv_node.node) ^
           std::hash<int>()(recv_node.port_num) ^
           std::hash<string>()(recv_node.device);
  }
};

struct RecvNodeDescriptorEqual {
  bool operator()(const RecvNodeDescriptor& a,
                  const RecvNodeDescriptor& b) const {
    return a.node == b.node && a.port_num == b.port_num && a.device == b.device;
  }
};
}  // namespace

// ReadyNodeManager
const NodeDef* LIFOManager::GetCurrNode() {
  CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
  if (curr_pos_ == nodes_.end()) {
    curr_pos_ = --(nodes_.rbegin().base());  // Last one in the list.
  }
  // Once curr_pos_ is set to a valid entry in the list, we keep using the
  // cached curr_pos_ until RemoveCurrNode() is called. AddNode() will not
  // change the GetCurrNode() return value.
  return *curr_pos_;
}

void LIFOManager::RemoveCurrNode() {
  // Make sure we have curr_pos_ ready to be removed.
  GetCurrNode();
  // Note that curr_pos_ may not be pointing to the last element if nodes
  // were added after the last GetCurrNode() call.
  nodes_.erase(curr_pos_);

  curr_pos_ = nodes_.end();  // Reset curr_pos_.
}

FirstReadyManager::FirstReadyManager() : ReadyNodeManager() {
  std::make_heap(nodes_.begin(), nodes_.end());
}

void FirstReadyManager::Init(
    const std::unordered_map<const NodeDef*, NodeState>* node_state) {
  // Reset the node state since different instances of the scheduler can reuse
  // the same node_manager.
  node_state_ = node_state;
  nodes_.clear();
  waiting_queue_.clear();
  greater_ = [this](const NodeDef* a, const NodeDef* b) -> bool {
    if (node_state_->at(a).time_ready == node_state_->at(b).time_ready) {
      // Use the node name as a tie-breaker for deterministic node scheduling.
      return a->name().compare(b->name()) > 0;
    } else {
      // Note: we need the node with the minimum time_ready, not the maximum;
      // hence, the comparison function uses a > b.
      return node_state_->at(a).time_ready > node_state_->at(b).time_ready;
    }
  };
}

const NodeDef* FirstReadyManager::GetCurrNode() {
  if (nodes_.empty()) {
    // Nothing in nodes_; probably the very first call. Move waiting_queue_
    // to nodes_.
    DrainWaitingQueue();
    CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
  }
  return nodes_.front();
}

void FirstReadyManager::RemoveCurrNode() {
  if (nodes_.empty()) {
    // Make sure that there is a node to be removed at the front of nodes_.
    GetCurrNode();
  }
  std::pop_heap(nodes_.begin(), nodes_.end(), greater_);
  nodes_.pop_back();
  DrainWaitingQueue();
}

bool FirstReadyManager::Empty() const {
  return nodes_.empty() && waiting_queue_.empty();
}
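
// Illustrative sketch (not part of the build): std::push_heap / std::pop_heap
// build a max-heap with respect to the comparator, so using greater_ (which
// orders by descending time_ready) keeps the node with the *minimum*
// time_ready at nodes_.front(). The same idea with plain ints:
//
//   std::vector<int> h = {5, 1, 3};
//   auto greater = [](int a, int b) { return a > b; };
//   std::make_heap(h.begin(), h.end(), greater);  // h.front() == 1
//   std::pop_heap(h.begin(), h.end(), greater);   // moves 1 to h.back()
//   h.pop_back();                                 // h.front() == 3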

void FirstReadyManager::DrainWaitingQueue() {
  for (const auto* node : waiting_queue_) {
    // push_heap in AddNode() and pop_heap in RemoveCurrNode() guarantee that
    // the first element is the node with the minimum time_ready.
    nodes_.push_back(node);
    std::push_heap(nodes_.begin(), nodes_.end(), greater_);
  }
  waiting_queue_.clear();
}

CompositeNodeManager::CompositeNodeManager()
    : ReadyNodeManager(), send_manager_(), recv_manager_() {}

void CompositeNodeManager::Init(
    const std::unordered_map<const NodeDef*, NodeState>* node_state) {
  node_state_ = node_state;
  send_manager_.Init(node_state);
  recv_manager_.Init(node_state);
  curr_node_ = nullptr;
}

void CompositeNodeManager::AddNode(const NodeDef* node) {
  if (IsSend(*node)) {
    send_manager_.AddNode(node);
  } else if (IsRecv(*node)) {
    recv_manager_.AddNode(node);
  } else {
    const auto& device = node_state_->at(node).device_name;
    ops_lifo_map_[device].AddNode(node);
  }
}

const NodeDef* CompositeNodeManager::GetCurrNode() {
  if (curr_node_) return curr_node_;

  // Per-device LIFO for normal ops (neither _Send nor _Recv),
  // FirstReady for _Send and _Recv (separately), then
  // FirstReady globally among the LIFO-selected ops from each device and the
  // _Send and _Recv candidates.
  // Priority order if time_ready is equal: _Send, then _Recv, then the rest.
  std::vector<std::pair<const NodeDef*, Costs::Duration>> candidates;
  for (auto& ops_lifo : ops_lifo_map_) {
    if (!ops_lifo.second.Empty()) {
      const auto* op = ops_lifo.second.GetCurrNode();
      candidates.emplace_back(op, node_state_->at(op).time_ready);
    }
  }
  if (!send_manager_.Empty()) {
    const auto* send = send_manager_.GetCurrNode();
    candidates.emplace_back(send, node_state_->at(send).time_ready);
  }
  if (!recv_manager_.Empty()) {
    const auto* recv = recv_manager_.GetCurrNode();
    candidates.emplace_back(recv, node_state_->at(recv).time_ready);
  }
  CHECK(!candidates.empty());
  auto first_ready = std::min_element(
      candidates.begin(), candidates.end(),
      [](const std::pair<const NodeDef*, Costs::Duration>& a,
         const std::pair<const NodeDef*, Costs::Duration>& b) {
        if (a.second == b.second) {
          // Note that candidates can contain at most one _Send and at most
          // one _Recv; hence, the score is 2 for _Send, 1 for _Recv, and 0
          // for a normal op, and a_score and b_score are equal only if both
          // are normal ops.
          int a_score = 2 * IsSend(*a.first) + IsRecv(*a.first);
          int b_score = 2 * IsSend(*b.first) + IsRecv(*b.first);
          if (a_score == b_score) {
            // Both are normal ops; use the node name as a tie-breaker.
            return a.first->name().compare(b.first->name()) < 0;
          } else {
            // Prioritize by op type: _Send, then _Recv, then normal ops.
            return a_score > b_score;
          }
        } else {
          return a.second < b.second;
        }
      });
  // Until RemoveCurrNode() is called, subsequent calls to GetCurrNode()
  // just return the cached curr_node_.
  curr_node_ = first_ready->first;

  return curr_node_;
}

void CompositeNodeManager::RemoveCurrNode() {
  const auto* node = GetCurrNode();
  if (IsSend(*node)) {
    send_manager_.RemoveCurrNode();
  } else if (IsRecv(*node)) {
    recv_manager_.RemoveCurrNode();
  } else {
    const auto device = node_state_->at(node).device_name;
    ops_lifo_map_[device].RemoveCurrNode();
  }
  // Reset curr_node_ so that GetCurrNode() finds another node.
  curr_node_ = nullptr;
}
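
// Illustrative sketch (not part of the build): the tie-break score in
// GetCurrNode() above maps op type to an integer so that, for equal
// time_ready values, _Send (score 2) runs before _Recv (score 1), which runs
// before normal ops (score 0). With three candidates ready at the same time t:
//
//   candidates = {("MatMul", t), ("_Recv_x", t), ("_Send_y", t)}
//   // scores:        0              1               2
//   // std::min_element with the comparator above selects "_Send_y".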

bool CompositeNodeManager::Empty() const {
  // Empty if all the ready managers are empty.
  bool empty = true;
  for (const auto& ops_lifo : ops_lifo_map_) {
    empty &= ops_lifo.second.Empty();
  }
  return empty && send_manager_.Empty() && recv_manager_.Empty();
}

VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item,
                                   const bool use_static_shapes,
                                   Cluster* cluster,
                                   ReadyNodeManager* ready_nodes)
    : ready_nodes_(ready_nodes),
      graph_costs_(Costs::ZeroCosts()),
      graph_properties_(*grappler_item),
      cluster_(cluster),
      grappler_item_(grappler_item),
      use_static_shapes_(use_static_shapes),
      placer_(cluster) {
  initialized_ = false;
}

ReadyNodeManager* VirtualScheduler::ReadyNodeManagerFactory(
    const string& ready_node_manager) {
  if (ready_node_manager == "FIFO") {
    return new FIFOManager();
  } else if (ready_node_manager == "LIFO") {
    return new LIFOManager();
  } else if (ready_node_manager == "FirstReady") {
    return new FirstReadyManager();
  } else if (ready_node_manager == "Composite") {
    return new CompositeNodeManager();
  }
  LOG(FATAL) << "Not a valid ready node manager: " << ready_node_manager;
}

Status VirtualScheduler::Init() {
  // Init() preprocesses the input grappler_item and graph_properties to
  // extract the information necessary for emulating TensorFlow op scheduling,
  // and constructs the internal data structures (NodeState and DeviceState)
  // used for virtual scheduling.
  ready_nodes_->Init(GetNodeStates());
  // Construct graph properties.
  Status status;
  if (use_static_shapes_) {
    status = graph_properties_.InferStatically(true);
  } else {
    status = graph_properties_.InferDynamically(cluster_);
  }
  if (!status.ok()) {
    return status;
  }

  const auto& graph = grappler_item_->graph;
  const auto& fetch_nodes = grappler_item_->fetch;
  std::set<string> feed_nodes;
  for (const auto& f : grappler_item_->feed) {
    auto iter_and_inserted_flag = feed_nodes.insert(f.first);
    QCHECK(iter_and_inserted_flag.second)
        << "Duplicate feed node found: " << f.first;
  }

  // Get the nodes that would run to output fetch_nodes.
  bool ill_formed = false;
  std::vector<const NodeDef*> nodes =
      ComputeTransitiveFanin(graph, fetch_nodes, &ill_formed);
  if (ill_formed) {
    return errors::InvalidArgument(
        "Ill formed graph or invalid set of fetch nodes specified");
  }

  // TODO(dyoon): this is a bit inefficient as name_to_node is already built
  // in ComputeTransitiveFanin().
  // Once ComputeTransitiveFanin() is complete, only the nodes that can be
  // reached from the fetch nodes are scheduled, so the scheduled nodes should
  // be exactly the same as those executed for real. One possible discrepancy
  // is control flow nodes, where TF executes only one path.
  std::unordered_map<string, const NodeDef*> name_to_node;
  for (const auto& node : nodes) {
    name_to_node[node->name()] = node;
  }

  // TODO(dyoon): Instead of identifying _Send nodes manually here, add _Send
  // to _Recv as a control dependency when creating the GrapplerItem.
  std::unordered_map<string, const NodeDef*> name_to_send;
  for (const auto& node : graph.node()) {
    if (IsSend(node)) {
      const auto& attr = node.attr();
      name_to_send[attr.at("tensor_name").s()] = &node;
    }
  }

  // To reuse _Recv ops.
  std::unordered_map<RecvNodeDescriptor, const NodeDef*,
                     RecvNodeDescriptorHash, RecvNodeDescriptorEqual>
      cached_recv_nodes;

  // Build node_map; for each node, create its NodeState and connect its
  // inputs and outputs.
  for (const auto* curr_node : nodes) {
    auto& curr_node_state = GetNodeStateOrCreateIt(curr_node);
    const string curr_node_device = DeviceName(curr_node);
    std::vector<string> inputs;
    if (IsRecv(*curr_node)) {
      const auto& attr = curr_node->attr();
      const NodeDef* send = name_to_send[attr.at("tensor_name").s()];
      inputs = {send->name()};
    } else {
      for (const string& input : curr_node->input()) {
        inputs.push_back(input);
      }
    }
    for (const string& input_node_name : inputs) {
      // Note that input_node_name may be in <prefix><node_name>:<port_num>
      // format, where <prefix> (e.g., "^" for control dependency) and
      // ":<port_num>" may be omitted. NodeName() extracts only the node_name.
      const NodeDef* input_node = name_to_node[NodeName(input_node_name)];

      CHECK(input_node);
      const string in_device = DeviceName(input_node);
      const auto input_node_port_num = NodePosition(input_node_name);

      if (curr_node_device == in_device) {
        // Same device: connect input_node and curr_node directly.
        curr_node_state.inputs.push_back(
            std::make_pair(input_node, input_node_port_num));
        auto& input_node_state = GetNodeStateOrCreateIt(input_node);
        input_node_state.outputs[input_node_port_num].push_back(curr_node);
      } else {
        RecvNodeDescriptor recv_node(input_node, input_node_port_num,
                                     curr_node_device);
        auto it = cached_recv_nodes.find(recv_node);
        if (it != cached_recv_nodes.end()) {
          // Different device, but found an already-cached copy (a _Recv op);
          // connect the _Recv to curr_node.
          const NodeDef* recv_op = it->second;
          // recv_op's output port is hard-coded to zero.
          curr_node_state.inputs.push_back(std::make_pair(recv_op, 0));
          auto& input_node_state = node_map_.at(recv_op);
          input_node_state.outputs[0].push_back(curr_node);
        } else {
          // Different device, no cached copy; transfer input_node to
          // curr_node's device.
          auto send_and_recv =
              CreateSendRecv(input_node, curr_node, input_node_name);
          // Note that CreateSendRecv() already connected the input/output
          // between the _Send and _Recv ops.
          const auto* send = send_and_recv.first;
          const auto* recv = send_and_recv.second;
          // recv_op's output port is hard-coded to zero.
          curr_node_state.inputs.push_back(std::make_pair(recv, 0));
          auto& input_node_state = GetNodeStateOrCreateIt(input_node);
          input_node_state.outputs[input_node_port_num].push_back(send);

          // Cache the _Recv op for future use.
          cached_recv_nodes[recv_node] = recv;
        }
      }
    }

    // Special case: given feed nodes are ready at time 0.
    const bool given_as_feed =
        feed_nodes.find(curr_node->name()) != feed_nodes.end();

    // Default case: nodes without inputs are ready at time 0.
    const bool has_no_inputs = curr_node->input().empty();

    if (!IsRecv(*curr_node) && (given_as_feed || has_no_inputs)) {
      curr_node_state.time_ready = Costs::Duration();
      ready_nodes_->AddNode(curr_node);
      VLOG(3) << "Added ready node: " << curr_node->name();
    }

    feed_nodes.erase(curr_node->name());

    if (IsPersistentNode(curr_node)) {
      auto& device_state = device_[curr_node_device];
      for (int port_num = 0;
           port_num < curr_node_state.output_properties.size(); ++port_num) {
        device_state.persistent_nodes.insert(
            std::make_pair(curr_node, port_num));
      }
    }
  }

  if (ready_nodes_->Empty()) {
    return errors::InvalidArgument("No ready nodes in the graph.");
  }

  if (!feed_nodes.empty()) {
    return errors::InvalidArgument(
        strings::StrCat("Some feed nodes were not found in the graph: ",
                        str_util::Join(feed_nodes, ",")));
  }
  initialized_ = true;
  return Status::OK();
}

void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) {
  CHECK(!initialized_) << "MaybeUpdateInputOutput is called after Init().";
  // This method is called when a NodeState is created; it adds input and
  // output properties for a few exceptional cases where GraphProperties
  // cannot provide them.
  if ((IsSend(*node) || IsRecv(*node)) && node->attr().count(kAttrInputSrc)) {
    // _Send and _Recv ops created by VirtualScheduler have the kAttrInputSrc
    // attr; normal _Send and _Recv ops (from the input graph) do not have it.
    auto& node_state = node_map_[node];
    auto& inputs = node_state.input_properties;
    auto& outputs = node_state.output_properties;

    // These _Send and _Recv ops are created by VirtualScheduler, so there
    // should be no input or output TensorProperties yet.
    CHECK(inputs.empty());
    CHECK(outputs.empty());
    const auto& attr = node->attr();
    // This is the original input source to the _Send and _Recv. The string
    // includes "^" if it was a control dependency, and an output port
    // (e.g., ":2") if the input source had multiple outputs.
    const auto& input_source_name = attr.at(kAttrInputSrc).s();
    if (IsControlInput(input_source_name)) {
      // Control dependency; regardless of the input source tensor size,
      // send 4B.
      OpInfo::TensorProperties control_message;
      control_message.set_dtype(DT_FLOAT);
      control_message.mutable_shape()->add_dim()->set_size(1);
      auto* value = control_message.mutable_value();
      value->add_float_val(1);
      inputs.push_back(control_message);
      outputs.push_back(control_message);
    } else {
      auto output_properties =
          graph_properties_.GetOutputProperties(NodeName(input_source_name));
      // As with HasInputProperties, if a node does not have output
      // properties, it was likely pruned during the shape inference run.
      if (!output_properties.empty()) {
        const auto input_node_port_num = NodePosition(input_source_name);
        // Use the input source's output property as the _Send's and _Recv's
        // input property.
        CHECK_GT(output_properties.size(), input_node_port_num);
        inputs.push_back(output_properties[input_node_port_num]);
        outputs.push_back(output_properties[input_node_port_num]);
      }
    }
  }
}
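
// Illustrative sketch (not part of the build): a control dependency carries
// no tensor data, so MaybeUpdateInputOutput() models it as a single DT_FLOAT
// scalar, i.e., 4 bytes on the wire:
//
//   DataTypeSize(DT_FLOAT) * 1  // == 4 bytes for a shape-[1] tensor,
//                               // matching the 4B that CalculateOutputSize()
//                               // returns for port_num < 0.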

float VirtualScheduler::Round2(const float x) const {
  // Not using std::round from <cmath> here because not all platforms seem to
  // support it (specifically Android).
  return ::round(100.0 * x) / 100.0;
}

bool VirtualScheduler::IsPersistentNode(const NodeDef* node) const {
  // Variables are persistent nodes.
  return IsVariable(*node);
}

string VirtualScheduler::DeviceName(const NodeDef* node) const {
  return placer_.get_canonical_device_name(*node);
}

string VirtualScheduler::SanitizedDeviceName(const NodeDef* node) const {
  // Replace the ":" characters that may be present in the device name with
  // "_". This makes it possible to then use the resulting string in a node
  // name.
  return str_util::StringReplace(placer_.get_canonical_device_name(*node),
                                 ":", "_", true);
}

string VirtualScheduler::ChannelDeviceName(const NodeDef* from,
                                           const NodeDef* to) const {
  CHECK(!initialized_) << "ChannelDeviceName is called after Init().";
  return kChannelDevice + "_from_" + SanitizedDeviceName(from) + "_to_" +
         SanitizedDeviceName(to);
}

std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
    const NodeDef* from, const NodeDef* to, const string& input_name) {
  CHECK(!initialized_) << "CreateSendRecv is called after Init().";

  // Connect the "from" node to the "to" node with _Send and _Recv such that
  //   from -> _Send -> _Recv -> to.
  // _Send is placed on the "Channel" device, and _Recv is on the same device
  // as the "to" node.
  // input_name is the string the "to" node uses to identify which output it
  // gets from the "from" node.

  // Note that we use NodeState for scheduling, so the _Send and _Recv
  // NodeDefs created here need not be fully correct in terms of name,
  // input names, attrs, etc.

  auto input_node_port_num = NodePosition(input_name);
  string src_name;
  if (input_node_port_num >= 0) {
    src_name = strings::StrCat(from->name(), "_", input_node_port_num);
  } else {
    src_name = strings::StrCat(from->name(), "_minus1");
  }

  // _Send op.
  auto* send = new NodeDef();
  send->set_name("Send_" + src_name + "_from_" + SanitizedDeviceName(from) +
                 "_to_" + SanitizedDeviceName(to));
  send->set_op("_Send");
  send->add_input(from->name());
  send->set_device(ChannelDeviceName(from, to));
  auto& send_attr = *(send->mutable_attr());
  send_attr[kAttrInputSrc].set_s(input_name);
  send_attr[kAttrSrcDevice].set_s(DeviceName(from));
  send_attr[kAttrDstDevice].set_s(DeviceName(to));

  // _Recv op.
  auto* recv = new NodeDef();
  recv->set_name("Recv_" + src_name + "_on_" + SanitizedDeviceName(to));
  recv->set_op("_Recv");
  recv->add_input(send->name());
  recv->set_device(DeviceName(to));
  auto& recv_attr = *(recv->mutable_attr());
  recv_attr[kAttrInputSrc].set_s(input_name);

  // NodeState for the _Send op.
  auto& send_node_state = GetNodeStateOrCreateIt(send);
  send_node_state.device_name = send->device();  // Set Channel device.
  send_node_state.inputs.push_back(std::make_pair(from, input_node_port_num));
  send_node_state.outputs[0].push_back(recv);

  // NodeState for the _Recv op.
  auto& recv_node_state = GetNodeStateOrCreateIt(recv);
  recv_node_state.inputs.push_back(std::make_pair(send, 0));
  recv_node_state.outputs[0].push_back(to);

  // Keep the created nodes.
  additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(send));
  additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(recv));

  // Return the _Send and _Recv pair.
  return std::make_pair(send, recv);
}
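
// Illustrative sketch (not part of the build): for an input string "x:2"
// crossing device boundaries, CreateSendRecv() would produce names like
//   _Send: "Send_x_2_from_<sanitized src device>_to_<sanitized dst device>"
//          on a Channel device, and
//   _Recv: "Recv_x_2_on_<sanitized dst device>" on the destination device,
// with the edges  x -> _Send -> _Recv -> consumer  recorded in NodeState.
// (The exact device substrings depend on the placer's canonical device
// names.)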

OpContext VirtualScheduler::GetCurrNode() const {
  const NodeDef* node = ready_nodes_->GetCurrNode();

  // Get the device from the placer.
  DeviceProperties device;
  device = placer_.get_device(*node);

  // Special case for _Send op.
  if (IsSend(*node)) {
    device.set_type(kChannelDevice);
  }

  // Construct OpContext.
  OpContext op_context;
  const auto& node_state = node_map_.at(node);
  op_context.name = node->name();
  op_context.device_name = node_state.device_name;
  auto& op_info = op_context.op_info;
  op_info.set_op(node->op());
  *op_info.mutable_attr() = node->attr();
  for (auto& input : node_state.input_properties) {
    *op_info.add_inputs() = input;
  }
  for (auto& output : node_state.output_properties) {
    *op_info.add_outputs() = output;
  }
  op_info.mutable_device()->Swap(&device);

  if (grappler_item_->graph.has_library()) {
    op_context.function_library = &grappler_item_->graph.library();
  }
  return op_context;
}

NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
  CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init().";

  auto it = node_map_.find(node);
  if (it == node_map_.end()) {
    // Not found; create a NodeState for this node.
    it = node_map_.emplace(node, NodeState()).first;
    auto& node_state = it->second;
    node_state.input_properties =
        graph_properties_.GetInputProperties(node->name());
    node_state.output_properties =
        graph_properties_.GetOutputProperties(node->name());

    // Some ops (_Send and _Recv) may need further processing of the
    // input/output properties.
    MaybeUpdateInputOutput(node);

    if (!IsSend(*node)) {
      node_state.device_name = DeviceName(node);
      // For a _Send op, device_name is set to Channel in CreateSendRecv().
    }

    // Initialize output-port-related data, assuming that the size of
    // OutputProperties represents the number of output ports of this node.
    for (size_t i = 0; i < node_state.output_properties.size(); ++i) {
      node_state.time_no_references[i] = Costs::Duration::max();
      node_state.num_outputs_executed[i] = 0;
      // Populate an empty vector for each port. The caller will add nodes
      // that use this port as input.
      node_state.outputs[i] = {};
    }
    // Port number -1 is for control dependencies.
    node_state.time_no_references[-1] = Costs::Duration::max();
    node_state.num_outputs_executed[-1] = 0;
    node_state.outputs[-1] = {};
  }
  return it->second;
}
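
// Illustrative sketch (not part of the build): CalculateOutputSize() below
// returns the output tensor size in bytes. For example, a DT_FLOAT tensor of
// shape [32, 128] yields 4 * 32 * 128 = 16384 bytes; a negative port number
// (a control dependency) yields 4 bytes; and any unknown (negative) dimension
// collapses the estimate to 0 bytes.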

int64 VirtualScheduler::CalculateOutputSize(
    const std::vector<OpInfo::TensorProperties>& output_properties,
    const int port_num) const {
  if (port_num < 0) {
    return 4;  // 4B for control dependency.
  }

  if (port_num >= output_properties.size()) {
    VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- "
            << "port_num: " << port_num
            << " >= output_properties.size(): " << output_properties.size();
    return 0;
  }

  const auto& output = output_properties[port_num];
  int64 output_size = DataTypeSize(BaseType(output.dtype()));

  for (const auto& dim : output.shape().dim()) {
    auto dim_size = dim.size();
    if (dim_size < 0) {
      // Zero output size if there's any unknown dim.
      output_size = 0;
      VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- "
              << "unknown dim: " << output_size;
      break;
    }
    output_size *= dim_size;
  }

  return output_size;
}

Costs& VirtualScheduler::FindOrCreateZero(const string& op_name,
                                          std::map<string, Costs>* op_cost) {
  auto it = op_cost->find(op_name);
  if (it == op_cost->end()) {
    // Note that the default constructor of Costs sets some memory-related
    // fields to unknown values, so we explicitly initialize it with
    // Costs::ZeroCosts().
    it = op_cost->emplace(op_name, Costs::ZeroCosts()).first;
  }
  return it->second;
}
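
// Illustrative sketch (not part of the build): FindOrCreateZero() must seed
// new entries with Costs::ZeroCosts() rather than a default-constructed
// Costs, because CombineCosts() CHECK-fails on kMemoryUnknown fields in its
// left operand (node_costs here is hypothetical):
//
//   std::map<string, Costs> per_op;
//   Costs& c = FindOrCreateZero("MatMul", &per_op);  // zero-initialized
//   c = CombineCosts(c, node_costs);                 // safe: no unknowns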

bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
  // Update graph_costs_ and per-op costs.
  graph_costs_ = CombineCosts(graph_costs_, node_costs);
  const NodeDef* node = ready_nodes_->GetCurrNode();
  const string& op_name = node->op();

  // Also keep track of op counts and times per op (with their shapes).
  OpContext op_context = GetCurrNode();
  string node_description = GetOpDescription(op_context.op_info);
  op_counts_[node_description] += 1;
  op_costs_[node_description] =
      std::make_pair(node_costs.execution_time.asMicroSeconds().count(),
                     !node_costs.inaccurate);

  auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_);
  op_cost = CombineCosts(op_cost, node_costs);

  // Update node and device states.
  auto& node_state = node_map_[node];
  auto& device = device_[node_state.device_name];
  device.nodes_executed.push_back(node);
  // A node is scheduled when the device is available AND all of its inputs
  // are ready; hence, time_scheduled is time_ready if time_ready > device
  // current time.
  node_state.time_scheduled =
      std::max(device.GetCurrTime(), node_state.time_ready);
  // Override the device current time with time_scheduled.
  device.device_costs.execution_time = node_state.time_scheduled;
  device.device_costs = CombineCosts(device.device_costs, node_costs);
  auto curr_time = device.GetCurrTime();
  node_state.time_finished = curr_time;

  // Update device memory usage.
  if (!IsPersistentNode(node)) {
    for (const auto& port_num_output_pair : node_state.outputs) {
      int port_num = port_num_output_pair.first;
      // There's a chance that a specific output is not used at all.
      if (node_state.outputs[port_num].empty()) {
        node_state.time_no_references[port_num] = curr_time;
      } else {
        device.memory_usage +=
            CalculateOutputSize(node_state.output_properties, port_num);
        device.nodes_in_memory.insert(std::make_pair(node, port_num));
      }
    }
  }

  // Update the device's per-op cost.
  auto& device_op_cost = FindOrCreateZero(op_name, &device.op_to_cost);
  device_op_cost = CombineCosts(device_op_cost, node_costs);

  VLOG(3) << "Op scheduled -- name: " << node->name() << ", op: " << node->op()
          << ", device: " << node->device()
          << ", ready: " << node_state.time_ready.count()
          << ", scheduled: " << node_state.time_scheduled.count()
          << ", finished: " << node_state.time_finished.count();

  // Increment num_inputs_ready of the output nodes.
  for (const auto& port_num_output_pair : node_state.outputs) {
    for (auto* output_node : port_num_output_pair.second) {
      auto& output_state = node_map_[output_node];
      output_state.num_inputs_ready++;
      // Execute a node as soon as all of its inputs are ready. Merge nodes
      // are special since they run as soon as one of their inputs becomes
      // available.
      if (output_state.num_inputs_ready == output_state.inputs.size() ||
          IsMerge(*output_node)) {
        // This output node is now ready.
        output_state.time_ready = curr_time;
        ready_nodes_->AddNode(output_node);
      }
    }
  }

  // Increment num_outputs_executed of the input nodes.
  for (const auto& input_port : node_state.inputs) {
    auto* input = input_port.first;
    auto port = input_port.second;
    auto& input_state = node_map_[input];
    input_state.num_outputs_executed[port]++;
    if (input_state.num_outputs_executed[port] ==
            input_state.outputs[port].size() &&
        !IsPersistentNode(input)) {
      // All the outputs are executed; there are no more references to this
      // output port of the input node.
      input_state.time_no_references[port] = curr_time;
      auto& input_device = device_[input_state.device_name];
      input_device.memory_usage -=
          CalculateOutputSize(input_state.output_properties, port);

      input_device.nodes_in_memory.erase(std::make_pair(input, port));
    }
  }

  if (!IsPersistentNode(node)) {
    // Now that output memory has been added and used-up nodes have been
    // deallocated, check the max memory usage.
    if (device.memory_usage > device.max_memory_usage) {
      device.max_memory_usage = device.memory_usage;
      device.mem_usage_snapshot_at_peak = device.nodes_in_memory;
    }
  }

  // Remove the current node; assume FIFO.
  ready_nodes_->RemoveCurrNode();

  return !ready_nodes_->Empty();
}
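
// Illustrative sketch (not part of the build): a typical driver loop over the
// scheduler; cost_estimator is a hypothetical per-op cost predictor that maps
// an OpContext to a Costs value:
//
//   VirtualScheduler scheduler(&item, use_static_shapes, cluster, &manager);
//   TF_RETURN_IF_ERROR(scheduler.Init());
//   bool more_nodes = true;
//   while (more_nodes) {
//     OpContext op_context = scheduler.GetCurrNode();
//     Costs node_costs = cost_estimator.PredictCosts(op_context);
//     more_nodes = scheduler.MarkCurrNodeExecuted(node_costs);
//   }
//   Costs total = scheduler.Summary();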

Costs VirtualScheduler::Summary() const {
  // Print out the basic execution summary.
  VLOG(1) << "Expected execution time: "
          << graph_costs_.execution_time.count();
  VLOG(1) << "Expected max memory: " << graph_costs_.max_memory;
  VLOG(1) << "Expected max per-op buffers: "
          << graph_costs_.max_per_op_buffers;
  VLOG(1) << "Expected max per-op streaming buffers: "
          << graph_costs_.max_per_op_streaming;

  VLOG(1) << "Per-op execution time:";
  for (const auto& op_cost_pair : op_to_cost_) {
    const auto& op = op_cost_pair.first;
    const auto& cost = op_cost_pair.second.execution_time.count();
    const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
    if (cost) {  // Skip printing out zero-cost ops.
      VLOG(1) << " + " << op << " : " << (is_op_cost_accurate ? "" : "~")
              << cost;
    }
  }

  // Print the per-device summary.
  VLOG(1) << "Devices:";
  Costs critical_path_costs = Costs::ZeroCosts();

  for (const auto& device : device_) {
    const auto& name = device.first;
    const auto& state = device.second;

    std::map<string, int64> op_to_memory;
    // First, profile only persistent memory usage.
    int64 persistent_memory_usage = 0;
    std::set<string> persistent_ops;
    for (const auto& node_port : state.persistent_nodes) {
      const auto* node = node_port.first;
      const auto port = node_port.second;
      const auto output_size =
          CalculateOutputSize(node_map_.at(node).output_properties, port);
      persistent_memory_usage += output_size;
      op_to_memory[node->op()] += output_size;
      persistent_ops.insert(node->op());
    }
    int64 max_memory_usage = persistent_memory_usage + state.max_memory_usage;
    critical_path_costs.estimated_max_memory_per_device[name] =
        max_memory_usage;

    const Costs::NanoSeconds wall_time_ns = state.GetCurrTime();
    VLOG(1) << "Device = " << name
            << ", num_nodes = " << state.nodes_executed.size()
            << ", wall_time_ns = " << wall_time_ns.count()
            << ", memory usage: "
            << "persistent = "
            << strings::HumanReadableNumBytes(persistent_memory_usage)
            << ", peak = "
            << strings::HumanReadableNumBytes(state.max_memory_usage)
            << ", total = " << strings::HumanReadableNumBytes(max_memory_usage)
            << ", at the end: "
            << strings::HumanReadableNumBytes(state.memory_usage);

    VLOG(1) << "Per-op execution time (and memory usage at peak memory "
               "usage):";

    // Profile non-persistent op memory usage.
    for (const auto& node_port : state.mem_usage_snapshot_at_peak) {
      const auto* node = node_port.first;
      const auto port = node_port.second;
      op_to_memory[node->op()] +=
          CalculateOutputSize(node_map_.at(node).output_properties, port);
    }
    Costs::NanoSeconds total_compute_time_ns;
    bool is_total_cost_accurate = true;
    for (const auto& op_cost_pair : state.op_to_cost) {
      const auto& op = op_cost_pair.first;
      const auto& cost = op_cost_pair.second.execution_time.count();
      total_compute_time_ns += op_cost_pair.second.execution_time;
      const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
      if (!is_op_cost_accurate) {
        is_total_cost_accurate = false;
      }

      int64 op_mem_usage = 0;
      auto it = op_to_memory.find(op);
      if (it != op_to_memory.end()) {
        op_mem_usage = it->second;
      }

      const float mem_usage_percent =
          max_memory_usage > 0
              ? Round2(100.0 * op_mem_usage / max_memory_usage)
              : 0.0;
      if (cost || mem_usage_percent > 1.0) {
        // Print out only non-zero-cost ops or ops with > 1% memory usage.
        VLOG(1) << " + " << op << " : " << (is_op_cost_accurate ? "" : "~")
                << cost << " (" << strings::HumanReadableNumBytes(op_mem_usage)
                << " [" << mem_usage_percent << "%] "
                << (persistent_ops.count(op) > 0 ? ": persistent op)" : ")");
      }
    }

    int utilization = 0;
    if (wall_time_ns.count() > 0) {
      utilization = total_compute_time_ns.count() * 100 / wall_time_ns.count();
    }
    VLOG(1) << "Device = " << name << ", total_compute_time_ns = "
            << (is_total_cost_accurate ? "" : "~")
            << total_compute_time_ns.count()
            << ", utilization = " << utilization << "%";

    if (critical_path_costs.execution_time <= state.GetCurrTime()) {
      critical_path_costs = state.device_costs;
    }
  }

  if (VLOG_IS_ON(2)) {
    // Also log the op descriptions and their corresponding counts.
    VLOG(2) << "Node description, counts, cost:";
    for (const auto& item : op_counts_) {
      int cost;
      bool is_cost_accurate;
      std::tie(cost, is_cost_accurate) = op_costs_.at(item.first);
      VLOG(2) << "Node: " << item.first << ", Count: " << item.second
              << ", Individual Cost: " << (is_cost_accurate ? "" : "~")
              << cost;
    }
  }

  VLOG(1) << "Critical path execution time: "
          << critical_path_costs.execution_time.count();
  return critical_path_costs;
}
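
// Illustrative sketch (not part of the build): utilization above is the
// integer percentage of device-busy time over wall time. For example, with
// total_compute_time_ns = 750 and wall_time_ns = 1000:
//
//   utilization = 750 * 100 / 1000;  // == 75, i.e., 75%
//
// The device whose current time matches the overall critical path provides
// critical_path_costs.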
"" : "~") << cost; 946 } 947 } 948 949 VLOG(1) << "Critical path execution time: " 950 << critical_path_costs.execution_time.count(); 951 return critical_path_costs; 952 } 953 954 Costs VirtualScheduler::Summary(RunMetadata* metadata) { 955 if (metadata != nullptr) { 956 StepStats* stepstats = metadata->mutable_step_stats(); 957 for (const auto& device : device_) { 958 GraphDef* device_partition_graph = metadata->add_partition_graphs(); 959 DeviceStepStats* device_stepstats = stepstats->add_dev_stats(); 960 device_stepstats->set_device(device.first); 961 for (const auto& node_def : device.second.nodes_executed) { 962 const NodeState& nodestate = node_map_.at(node_def); 963 NodeExecStats* node_stats = device_stepstats->add_node_stats(); 964 uint64 total_output_size = 0; 965 for (int slot = 0; slot < nodestate.output_properties.size(); slot++) { 966 const auto& properties = nodestate.output_properties[slot]; 967 NodeOutput* no = node_stats->add_output(); 968 no->set_slot(slot); 969 TensorDescription* tensor_descr = no->mutable_tensor_description(); 970 tensor_descr->set_dtype(properties.dtype()); 971 *tensor_descr->mutable_shape() = properties.shape(); 972 // Optional allocation description. 973 const auto tensor_size = 974 CalculateOutputSize(nodestate.output_properties, slot); 975 total_output_size += tensor_size; 976 tensor_descr->mutable_allocation_description()->set_requested_bytes( 977 tensor_size); 978 tensor_descr->mutable_allocation_description()->set_allocated_bytes( 979 tensor_size); 980 } 981 node_stats->set_timeline_label(node_def->op()); 982 node_stats->set_node_name(node_def->name()); 983 node_stats->set_op_start_rel_micros(0); 984 node_stats->set_all_start_micros( 985 nodestate.time_scheduled.asMicroSeconds().count()); 986 node_stats->set_op_end_rel_micros( 987 nodestate.time_finished.asMicroSeconds().count() - 988 nodestate.time_scheduled.asMicroSeconds().count()); 989 node_stats->set_all_end_rel_micros( 990 nodestate.time_finished.asMicroSeconds().count() - 991 nodestate.time_scheduled.asMicroSeconds().count()); 992 auto* mem_stats = node_stats->mutable_memory_stats(); 993 // VirtualScheduler does not specify scratch pad memory usage. 994 mem_stats->set_temp_memory_size(0); 995 int64 persistent_memory_size = 0; 996 if (IsPersistentNode(node_def)) { 997 persistent_memory_size = total_output_size; 998 } 999 mem_stats->set_persistent_memory_size(persistent_memory_size); 1000 *device_partition_graph->add_node() = *node_def; 1001 } 1002 } 1003 } 1004 return Summary(); 1005 } 1006 1007 const std::unordered_map<string, int64> VirtualScheduler::GetPeakMemoryUsage() 1008 const { 1009 std::unordered_map<string, int64> result; 1010 for (const auto& device : device_) { 1011 const string& name = device.first; 1012 const DeviceState& state = device.second; 1013 result[name] = state.max_memory_usage; 1014 } 1015 return result; 1016 } 1017 1018 } // end namespace grappler 1019 } // end namespace tensorflow 1020