1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/common_runtime/placer.h" 17 18 #include <memory> 19 #include <set> 20 #include <utility> 21 #include <vector> 22 23 #include "tensorflow/core/common_runtime/device.h" 24 #include "tensorflow/core/framework/device_attributes.pb.h" 25 #include "tensorflow/core/framework/graph.pb.h" 26 #include "tensorflow/core/framework/node_def_util.h" 27 #include "tensorflow/core/framework/op_kernel.h" 28 #include "tensorflow/core/framework/types.h" 29 #include "tensorflow/core/framework/types.pb.h" 30 #include "tensorflow/core/lib/core/errors.h" 31 #include "tensorflow/core/lib/core/stringpiece.h" 32 33 namespace tensorflow { 34 35 namespace { 36 37 // We hoist the conversion from C-style string literal to StringPiece here, 38 // so that we can avoid the many repeated calls to strlen(). 39 const StringPiece kColocationAttrNameStringPiece(kColocationAttrName); 40 const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix); 41 42 // Returns a list of devices sorted by preferred type and then name 43 // from 'devices' whose type is in 'supported_device_types'. This 44 // function searches the device types in 'supported_device_types' and 45 // returns the subset of devices that match. 46 std::vector<Device*> FilterSupportedDevices( 47 const std::vector<Device*>& devices, 48 const DeviceTypeVector& supported_device_types) { 49 std::vector<Device*> filtered_devices; 50 for (const DeviceType& d : supported_device_types) { 51 for (Device* device : devices) { 52 if (DeviceType(device->attributes().device_type()) == d) { 53 filtered_devices.emplace_back(device); 54 } 55 } 56 } 57 58 auto device_sort = [](const Device* a, const Device* b) { 59 auto a_priority = DeviceSet::DeviceTypeOrder(DeviceType(a->device_type())); 60 auto b_priority = DeviceSet::DeviceTypeOrder(DeviceType(b->device_type())); 61 // First sort by prioritized device type (higher is preferred) and 62 // then by device name (lexicographically). 63 if (a_priority != b_priority) { 64 return a_priority > b_priority; 65 } 66 return StringPiece(a->name()) < StringPiece(b->name()); 67 }; 68 std::sort(filtered_devices.begin(), filtered_devices.end(), device_sort); 69 return filtered_devices; 70 } 71 72 // This class maintains the connected components of a colocation 73 // constraint graph, and uses this information to assign a satisfying 74 // device placement to the nodes of the graph. 75 // 76 // The typical usage pattern is: 77 // 78 // Graph graph = ...; 79 // DeviceSet device_set = ...; 80 // ColocationGraph colocation_graph(graph, device_set); 81 // 82 // // Add all the nodes of graph to colocation_graph. 83 // for (Node* node : graph.nodes()) { 84 // TF_RETURN_IF_ERROR(colocation_graph.AddNode(*node)); 85 // } 86 // 87 // // Add one or more colocation constraint. 88 // Node node_1 = *graph.FindNodeId(...); 89 // Node node_2 = *graph.FindNodeId(...); 90 // TF_RETURN_IF_ERROR(colocation_graph.ColocateNodes(node_1, node_2)); 91 // 92 // // Assign devices based on the accumulated constraints. 93 // for (Node* node : graph.nodes()) { 94 // TF_RETURN_IF_ERROR(colocation_graph.AssignDevice(node)); 95 // } 96 // 97 // The implementation uses the union-find algorithm to maintain the 98 // connected components efficiently and incrementally as edges 99 // (implied by ColocationGraph::ColocateNodes() invocations) are added. 100 class ColocationGraph { 101 public: 102 ColocationGraph(Graph* graph, const DeviceSet* device_set, 103 bool allow_soft_placement) 104 : graph_(graph), 105 device_set_(device_set), 106 device_types_(device_set->PrioritizedDeviceTypeList()), 107 allow_soft_placement_(allow_soft_placement) { 108 members_.resize(graph->num_node_ids()); 109 } 110 111 // Adds each node of the Graph to this ColocationGraph as a singleton. 112 // 113 // NOTE: The implementation assumes that the ids of nodes passed to 114 // this method are dense and zero-based; the memory used will be linear in 115 // the largest node ID. 116 // NOTE: If this method returns an error, *this is left in an undefined 117 // state. 118 Status ColocateAllNodes() { 119 // This maps from a colocation group identifier to the 'root' of that 120 // colocation group. Note that the keys in this map are StringPiece; the 121 // actual strings are stored under the NodeDef. The lifetime of this map 122 // is limited to this ColocateAllNodes() method, and no part of the 123 // NodeDef trees are changed during the lifetime of this method, so using 124 // StringPiece as a key is safe. 125 // 126 // Also, as a further optimization, we remove the "loc:@" prefix from 127 // "class" attribute values, when they are used as keys in this table. 128 // This allows us to use StringPiece values that refer to substrings of 129 // 'string' values stored in NodeDef attribute lists, as well as StringPiece 130 // values that refer to 'string' values from NodeDef::name(), without 131 // performing any string allocations. 132 std::unordered_map<StringPiece, const Node*, StringPieceHasher> 133 colocation_group_root; 134 135 for (Node* node : graph_->nodes()) { 136 if (!node->IsOp()) { 137 continue; 138 } 139 140 // When adding the node, identify whether it is part of a 141 // colocation group. 142 143 // This code is effectively the equivalent of GetNodeAttr() for a string 144 // array, but it avoids all internal allocations (the allocation of the 145 // backing store of the std::vector<string> as well as the copies of the 146 // strings within it). Instead, we combine the query of the colocation 147 // attribute with the calls to ColocateNodeToGroup. 148 bool found_spec = false; 149 const AttrValue* attr_value = 150 node->attrs().Find(kColocationAttrNameStringPiece); 151 if (attr_value != nullptr && attr_value->has_list()) { 152 for (const string& class_spec : attr_value->list().s()) { 153 StringPiece spec(class_spec); 154 if (spec.Consume(kColocationGroupPrefixStringPiece)) { 155 found_spec = true; 156 TF_RETURN_IF_ERROR( 157 ColocateNodeToGroup(&colocation_group_root, node, spec)); 158 } 159 } 160 } 161 162 if (!found_spec) { 163 // If the node does not specify a colocation group, then use the 164 // name of this node as the colocation group. 165 TF_RETURN_IF_ERROR( 166 ColocateNodeToGroup(&colocation_group_root, node, node->name())); 167 } 168 } 169 170 return Status::OK(); 171 } 172 173 Status ColocateNodeToGroup( 174 std::unordered_map<StringPiece, const Node*, StringPieceHasher>* 175 colocation_group_root, 176 Node* node, StringPiece colocation_group) { 177 const Node*& root_node = (*colocation_group_root)[colocation_group]; 178 if (root_node == nullptr) { 179 // This is the first node of the colocation group, so 180 // designate this node as the 'root' of that colocation group. 181 root_node = node; 182 } else { 183 // Try to colocate the node with the root. If there is an 184 // error, return it. 185 Status s = ColocateNodes(*node, *root_node); 186 if (!s.ok()) { 187 return AttachDef(s, *node); 188 } 189 } 190 return Status::OK(); 191 } 192 193 // Merge the (possibly disjoint) sets containing nodes "x" and 194 // "y". Returns OK if the all nodes in the union of these sets can 195 // be placed on the same device type. 196 // 197 // NOTE: If this method returns an error, *this is left in an undefined 198 // state. 199 Status ColocateNodes(const Node& x, const Node& y) { 200 int x_root = FindRoot(x.id()); 201 int y_root = FindRoot(y.id()); 202 return ColocateNodes(x, x_root, y, y_root); 203 } 204 205 // This overload of ColocateNodes() allows a caller to provide the root node 206 // ids for the two nodes. For large graphs, this noticeably reduces the 207 // graph load time. 208 Status ColocateNodes(const Node& x, int x_root, const Node& y, int y_root) { 209 if (x_root == y_root) { 210 return Status::OK(); 211 } 212 213 DCHECK_EQ(x_root, FindRoot(x.id())); 214 DCHECK_EQ(y_root, FindRoot(y.id())); 215 216 Member& x_root_member = members_[x_root]; 217 Member& y_root_member = members_[y_root]; 218 219 // Merge the sets by swinging the parent pointer of the smaller 220 // tree to point to the root of the larger tree. Together with 221 // path compression in ColocationGraph::FindRoot, this ensures 222 // that we do not experience pathological performance on graphs 223 // such as chains. 224 int new_root, old_root; 225 if (x_root_member.rank < y_root_member.rank) { 226 // The tree rooted at x_root is shallower, so connect it to 227 // y_root. The rank of y_root is unchanged because its new 228 // child has strictly less rank. 229 x_root_member.parent = y_root; 230 new_root = y_root; 231 old_root = x_root; 232 } else if (x_root_member.rank > y_root_member.rank) { 233 // The tree rooted at y_root is shallower, so connect it to 234 // x_root. The rank of x_root is unchanged because its new 235 // child has strictly less rank. 236 y_root_member.parent = x_root; 237 new_root = x_root; 238 old_root = y_root; 239 } else { 240 // Both trees have the same rank, so break the tie by choosing 241 // x_root as the new root. 242 y_root_member.parent = x_root; 243 // Increment the rank of the tree rooted at x_root, because it 244 // is now strictly deeper than before. 245 ++x_root_member.rank; 246 new_root = x_root; 247 old_root = y_root; 248 } 249 250 Member& new_root_member = members_[new_root]; 251 Member& old_root_member = members_[old_root]; 252 253 // Merge the partial device specifications, and ensure that they are 254 // compatible. NULL options_ is treated as allowing soft placement. 255 // TODO(mrry): Consider enriching the error message by pointing 256 // out which nodes have the explicit partial device 257 // specifications that caused this conflict. 258 Status s = DeviceNameUtils::MergeDevNames(&new_root_member.device_name, 259 old_root_member.device_name, 260 allow_soft_placement_); 261 if (!s.ok()) { 262 return errors::InvalidArgument("Cannot colocate nodes '", x.name(), 263 "' and '", y.name(), ": ", 264 s.error_message()); 265 } 266 267 // Ensure that the common root has at least one supported device 268 // type, by computing the intersection of 269 // new_root_member.supported_device_types and 270 // old_root_member.supported_device_types. 271 MergeSupportedDevices(&new_root_member.supported_device_types, 272 old_root_member.supported_device_types); 273 if (new_root_member.supported_device_types.empty()) { 274 return errors::InvalidArgument( 275 "Cannot colocate nodes '", x.name(), "' and '", y.name(), 276 "' because no device type supports both of those nodes and the " 277 "other nodes colocated with them.", 278 DebugInfo(x_root), DebugInfo(y_root)); 279 } 280 281 return Status::OK(); 282 } 283 284 // For the given node, subject to the constraints previously given 285 // to this ColocationGraph, set its assigned_device_name. Returns OK 286 // if a satisfying device can be found, otherwise an error. 287 // 288 // Note: This method returns a pointer to a field within members_. 289 // The caller must not use the returned pointer after there is any possibility 290 // that the members_[i].possible_devices field has been modified. 291 Status GetDevicesForNode(Node* node, 292 std::vector<Device*>** possible_devices) { 293 *possible_devices = nullptr; 294 const int node_root = FindRoot(node->id()); 295 if (!members_[node_root].possible_devices.empty()) { 296 *possible_devices = &members_[node_root].possible_devices; 297 return Status::OK(); 298 } 299 300 // We have not yet computed the possible devices for the 301 // colocated node set containing 'node', so we do so now using the 302 // constraints on the root node. 303 304 // "devices" will contain the set of feasible placements for the 305 // colocated node set containing 'node'. 306 std::vector<Device*> devices; 307 if (DeviceNameUtils::HasSomeDetails(members_[node_root].device_name)) { 308 // The root node has a (possibly partial) device 309 // specification, so enumerate the physical devices that 310 // conform to it. 311 device_set_->FindMatchingDevices(members_[node_root].device_name, 312 &devices); 313 314 if (!devices.empty()) { 315 // Filter devices into those that are compatible with the root 316 // node (and its children). 317 devices = FilterSupportedDevices( 318 devices, members_[node_root].supported_device_types); 319 } 320 321 // Perform soft placement if allow_soft_placement_ is set. 322 if (devices.empty() && allow_soft_placement_) { 323 // The soft_device_name is the same as the node's device name 324 // without specifying the device type or ID. 325 DeviceNameUtils::ParsedName soft_device_name = 326 members_[node_root].device_name; 327 soft_device_name.type.clear(); 328 soft_device_name.has_type = false; 329 soft_device_name.has_id = false; 330 device_set_->FindMatchingDevices(soft_device_name, &devices); 331 if (!devices.empty()) { 332 devices = FilterSupportedDevices( 333 devices, members_[node_root].supported_device_types); 334 } 335 } 336 337 if (devices.empty()) { 338 // Return an error when a physical device that matches an explicit 339 // device specification is not found. This ensures that we don't 340 // assign a node to GPU when the user wanted to force it on CPU. 341 string debug_info = DebugInfo(node_root); 342 343 DeviceNameUtils::ParsedName specified_device_name; 344 if (DeviceNameUtils::ParseFullName(node->requested_device(), 345 &specified_device_name) && 346 specified_device_name == members_[node_root].device_name) { 347 // The specified device and merged set device match, and 348 // will appear in the GraphDef (for debugging), so just 349 // print the specified device. 350 std::vector<Device*> devices_matching_nodedef; 351 device_set_->FindMatchingDevices(specified_device_name, 352 &devices_matching_nodedef); 353 if (devices_matching_nodedef.empty()) { 354 // Sometimes it is almost impossible to understand the problem 355 // without a list of available devices. 356 std::vector<string> device_names; 357 for (const Device* device : device_set_->devices()) { 358 device_names.push_back(device->name()); 359 } 360 std::sort(device_names.begin(), device_names.end()); 361 362 return errors::InvalidArgument( 363 "Operation was explicitly assigned to ", 364 node->requested_device(), " but available devices are [ ", 365 str_util::Join(device_names, ", "), " ]. Make sure ", 366 "the device specification refers to a valid device."); 367 } else if (specified_device_name.has_type) { 368 return errors::InvalidArgument( 369 "Could not satisfy explicit device specification '", 370 node->requested_device(), "' because no supported kernel for ", 371 specified_device_name.type, " devices is available.", 372 debug_info, "\nRegistered kernels:\n", 373 KernelsRegisteredForOp(node->type_string())); 374 } else { 375 return errors::InvalidArgument( 376 "Could not satisfy explicit device specification '", 377 node->requested_device(), debug_info); 378 } 379 } else { 380 // The specified device may be a valid device but the 381 // merged set device is different, so print both. 382 return errors::InvalidArgument( 383 "Could not satisfy explicit device specification '", 384 node->requested_device(), 385 "' because the node was colocated with a group of nodes that " 386 "required incompatible device '", 387 DeviceNameUtils::ParsedNameToString( 388 members_[node_root].device_name), 389 "'", debug_info); 390 } 391 } 392 } else { 393 // The device is completely unspecified, so enumerate the devices that 394 // support all of the nodes in the set. 395 if (device_set_->devices().empty()) { 396 return errors::Internal("No devices are registered"); 397 } 398 devices = FilterSupportedDevices( 399 device_set_->devices(), members_[node_root].supported_device_types); 400 401 if (devices.empty()) { 402 return errors::InvalidArgument( 403 "Node had no OpKernel registered to support this operation: ", 404 "Operation was ", node->type_string(), " and inputs were ", 405 DataTypeVectorString(node->input_types()), DebugInfo(node_root)); 406 } 407 } 408 409 // Cache the result of the possible devices for this node group. 410 members_[node_root].possible_devices = std::move(devices); 411 *possible_devices = &members_[node_root].possible_devices; 412 return Status::OK(); 413 } 414 415 Status InitializeMembers() { 416 for (Node* node : graph_->nodes()) { 417 if (!node->IsOp()) { 418 continue; 419 } 420 Status status = InitializeMember(*node, &members_[node->id()]); 421 if (!status.ok()) { 422 return AttachDef(status, *node); 423 } 424 } 425 return Status::OK(); 426 } 427 428 // Represents a node in the disjoint node set forest, and the 429 // accumulated constraints on the device used by that node. 430 struct Member { 431 Member() = default; 432 // The id of the node that is the parent of this one, or its own 433 // id if it is a root. parent <= 0 indicates that this member is invalid. 434 int parent = -1; 435 436 // A proxy for the depth of the tree that is used to prefer 437 // connecting smaller trees to larger trees when merging disjoint 438 // sets. 439 int rank = 0; 440 441 // The intersection of all device types supported by this node, 442 // and those of all of its children, in priority order 443 // of the preferred device. 444 DeviceTypeVector supported_device_types; 445 446 // The merged form of the device requested for this node, with 447 // those of all of its children. 448 DeviceNameUtils::ParsedName device_name; 449 450 // If this node is a root, stores a list of Devices to which this node 451 // and all of its children have been assigned, or nullptr if this 452 // has not yet been computed. 453 std::vector<Device*> possible_devices; 454 }; 455 456 // Returns debugging info for the node referred to by 'node_root'. 457 string DebugInfo(const int node_root) { 458 string text( 459 "\nColocation Debug Info:\n" 460 "Colocation group had the following types and devices: "); 461 462 // If this node is part of a colocation group, then we want to 463 // collect the mapping of ops to supported devices, so that 464 // the user can see why an unsatisfiable placement occurred. 465 466 std::unordered_map<string, string> type_to_devices; 467 int num_nodes_found = 0; 468 469 for (const Node* node : graph_->nodes()) { 470 if (!node->IsOp()) { 471 continue; 472 } 473 int id = node->id(); 474 if (FindRoot(id) != node_root) { 475 continue; 476 } 477 ++num_nodes_found; 478 const string& op_type = node->type_string(); 479 string devices_registered; 480 for (const auto& device_type : members_[id].supported_device_types) { 481 strings::StrAppend(&devices_registered, DeviceTypeString(device_type), 482 " "); 483 } 484 485 type_to_devices[op_type] = std::move(devices_registered); 486 } 487 488 for (const auto& td : type_to_devices) { 489 strings::StrAppend(&text, "\n", td.first, ": ", td.second); 490 } 491 492 if (num_nodes_found <= 1) { 493 text.clear(); 494 } 495 return text; 496 } 497 498 Status InitializeMember(const Node& node, Member* member) { 499 const int id = node.id(); 500 DCHECK_GE(id, 0); 501 member->parent = id; 502 TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode( 503 device_types_, node.def(), &member->supported_device_types)); 504 505 if (node.has_assigned_device_name()) { 506 // This node has already been assigned to a device, so we 507 // respect this placement, after sanity-checking it. The 508 // device_name and supported_device_types for this node reflect 509 // the assigned device, so any nodes colocated with this node 510 // will be assigned to the same device (assuming this is 511 // possible). 512 // NOTE: Since any assignment must have been performed by 513 // the TensorFlow runtime, we consider errors in this branch to 514 // be INTERNAL. 515 const string& assigned_device_name = node.assigned_device_name(); 516 if (!DeviceNameUtils::ParseFullName(assigned_device_name, 517 &member->device_name)) { 518 return errors::Internal("Malformed assigned device '", 519 assigned_device_name, "'"); 520 } 521 const Device* assigned_device = 522 device_set_->FindDeviceByName(assigned_device_name); 523 if (assigned_device == nullptr) { 524 return errors::Internal("Assigned device '", assigned_device_name, 525 "' does not match any device"); 526 } 527 528 for (const DeviceType& d : member->supported_device_types) { 529 if (DeviceType(assigned_device->attributes().device_type()) == d) { 530 return Status::OK(); 531 } 532 } 533 534 return errors::Internal("Assigned device '", assigned_device_name, 535 "' does not have registered OpKernel support " 536 "for ", 537 node.type_string()); 538 } else { 539 // This node has not yet been assigned to a device, so we 540 // calculate any constraints due to the set of registered 541 // kernels and any (partial) user-provided device specification 542 // in the NodeDef. 543 544 // If no kernels are registered for this op type, fail with an error. 545 if (member->supported_device_types.empty()) { 546 std::set<string> registered_device_types; 547 for (Device* d : device_set_->devices()) { 548 registered_device_types.insert(d->device_type()); 549 } 550 return errors::InvalidArgument( 551 "No OpKernel was registered to support Op '", node.type_string(), 552 "' with these attrs. Registered devices: [", 553 str_util::Join(registered_device_types, ","), 554 "], Registered kernels:\n", 555 KernelsRegisteredForOp(node.type_string())); 556 } 557 558 // If the NodeDef contains a device, then we interpret it as a 559 // (partial) device specification. 560 if (!node.requested_device().empty()) { 561 // The user has specified a device in the NodeDef, try to find a 562 // valid device matching their specification in the set of 563 // devices. 564 // NOTE: The full name may specify a device that is not in 565 // n.supported_device_types(), but we check that in AssignDevice(). 566 if (!DeviceNameUtils::ParseFullName(node.requested_device(), 567 &member->device_name)) { 568 return errors::InvalidArgument("Malformed device specification '", 569 node.requested_device(), "'"); 570 } 571 } 572 } 573 return Status::OK(); 574 } 575 576 // Updates target to contain the intersection of the device types in 577 // "target" and "other". 578 static void MergeSupportedDevices(DeviceTypeVector* target, 579 const DeviceTypeVector& other) { 580 DeviceTypeVector temp = *target; 581 target->clear(); 582 583 // Iterate in priority order. 584 for (const DeviceType& device_type : temp) { 585 bool found = false; 586 for (const DeviceType& other_device_type : other) { 587 if (device_type == other_device_type) { 588 found = true; 589 break; 590 } 591 } 592 if (found) { 593 target->push_back(device_type); 594 } 595 } 596 } 597 598 // Returns the root node of the disjoint tree to which the node with the 599 // given id is connected. 600 int FindRoot(int node_id) { 601 Member& member = members_[node_id]; 602 603 int parent = member.parent; 604 DCHECK_GE(parent, 0); 605 606 if (parent != node_id) { 607 // NOTE: Compress paths from node_id to its root, so that future 608 // calls to FindRoot and ColocateNodes are more efficient. 609 int root = FindRoot(parent); 610 if (parent != root) { 611 parent = root; 612 member.parent = root; 613 } 614 } 615 616 DCHECK_GE(parent, 0); 617 return parent; 618 } 619 620 Graph* const graph_; // Not owned. 621 std::vector<Member> members_; 622 const DeviceSet* device_set_; // Not owned. 623 const std::vector<DeviceType> device_types_; 624 const bool allow_soft_placement_; 625 }; 626 627 // Returns true if the node has no inputs and produces outputs 628 // that are consumed by a single node. 629 // 630 // TODO(vrv): Currently this handles only nodes with one output, but 631 // this could be extended to handle the case where a node has many 632 // outputs that are connected to nodes in the same colocation group. 633 bool IsGeneratorNode(const Node* node) { 634 return node->num_inputs() == 0 && node->num_outputs() == 1 && 635 !IsRefType(node->output_type(0)); 636 } 637 638 } // namespace 639 640 Placer::Placer(Graph* graph, const DeviceSet* devices, 641 const SessionOptions* options) 642 : graph_(graph), 643 devices_(devices), 644 options_(options), 645 log_device_placement_(options != nullptr && 646 options->config.log_device_placement()) {} 647 648 Placer::Placer(Graph* graph, const DeviceSet* devices) 649 : Placer(graph, devices, nullptr) {} 650 651 Placer::~Placer() {} 652 653 Status Placer::Run() { 654 if (devices_->devices().empty()) { 655 return errors::FailedPrecondition("No devices are registered"); 656 } 657 658 ColocationGraph colocation_graph( 659 graph_, devices_, 660 options_ == nullptr || options_->config.allow_soft_placement()); 661 662 TF_RETURN_IF_ERROR(colocation_graph.InitializeMembers()); 663 664 // 1. First add all of the nodes. Note that steps (1) and (2) 665 // requires two passes over the nodes because the graph (and hence 666 // the constraints) may not be acyclic. 667 TF_RETURN_IF_ERROR(colocation_graph.ColocateAllNodes()); 668 669 // 2. Enumerate the constraint edges, and use them to update the disjoint 670 // node set. 671 672 // If `node` has an input edge with reference type, add an 673 // edge from the source of that edge to `node`. 674 for (const Edge* edge : graph_->edges()) { 675 if (edge->IsControlEdge()) { 676 continue; 677 } 678 Node* src = edge->src(); 679 Node* dst = edge->dst(); 680 DataType input_type = dst->input_type(edge->dst_input()); 681 if (input_type == DT_RESOURCE || IsRefType(input_type)) { 682 int src_root_id = colocation_graph.FindRoot(src->id()); 683 int dst_root_id = colocation_graph.FindRoot(dst->id()); 684 auto& src_root = colocation_graph.members_[src_root_id]; 685 auto& dst_root = colocation_graph.members_[dst_root_id]; 686 // If both the source node and this node have partially 687 // specified a device, then 'node's device should be 688 // cleared: the reference edge forces 'node' to be on the 689 // same device as the source node. 690 const auto& source_parsed_name = src_root.device_name; 691 const auto& dest_parsed_name = dst_root.device_name; 692 if (DeviceNameUtils::HasSomeDetails(source_parsed_name) && 693 DeviceNameUtils::HasSomeDetails(dest_parsed_name)) { 694 // Ignore a specified device for 'dst' if the two names were 695 // incompatible. 696 if (!DeviceNameUtils::AreCompatibleDevNames(source_parsed_name, 697 dest_parsed_name)) { 698 if (log_device_placement_) { 699 LOG(INFO) << "Ignoring device specification " 700 << DeviceNameUtils::ParsedNameToString(dest_parsed_name) 701 << " for node '" << dst->name() 702 << "' because the input edge from '" << src->name() 703 << "' is a reference connection and already has a device " 704 "field set to " 705 << DeviceNameUtils::ParsedNameToString( 706 source_parsed_name); 707 } 708 709 // Make 'dst' colocated with the source 710 dst_root.device_name = source_parsed_name; 711 } else { 712 bool source_subset_of_dest = DeviceNameUtils::IsSpecification( 713 source_parsed_name, dest_parsed_name); 714 bool dest_subset_of_source = DeviceNameUtils::IsSpecification( 715 dest_parsed_name, source_parsed_name); 716 717 if (source_subset_of_dest && !dest_subset_of_source) { 718 src_root.device_name = dest_parsed_name; 719 } else { 720 dst_root.device_name = source_parsed_name; 721 } 722 } 723 } 724 725 Status status = 726 colocation_graph.ColocateNodes(*src, src_root_id, *dst, dst_root_id); 727 if (!status.ok()) { 728 return AttachDef( 729 errors::InvalidArgument("Nodes were connected by a " 730 "reference connection (requiring them to " 731 "be on the same device), but the two nodes " 732 "were assigned two different devices: ", 733 status.error_message()), 734 *dst); 735 } 736 } 737 } 738 739 // 3. For each node, assign a device based on the constraints in the 740 // disjoint node set. 741 std::vector<Node*> second_pass; 742 for (Node* node : graph_->op_nodes()) { 743 // The graph may have come pre-populated by the framework with assigned 744 // devices (e.g., for stateful placements), so the placer should not try to 745 // place nodes that are already placed. 746 if (node->has_assigned_device_name()) { 747 LogDeviceAssignment(node); 748 continue; 749 } 750 751 // Heuristic A: prefer to place "generators" with their only 752 // consumers. 753 // 754 // If this is a node with no inputs and one output, we save 755 // this for a second pass, so that the consumer's placement 756 // is chosen. 757 if (IsGeneratorNode(node)) { 758 second_pass.push_back(node); 759 continue; 760 } 761 762 std::vector<Device*>* devices; 763 Status status = colocation_graph.GetDevicesForNode(node, &devices); 764 if (!status.ok()) { 765 return AttachDef( 766 errors::InvalidArgument("Cannot assign a device for operation '", 767 node->name(), "': ", status.error_message()), 768 *node); 769 } 770 771 // Returns the first device in sorted devices list so we will always 772 // choose the same device. 773 // 774 // TODO(vrv): Factor this assignment out into a pluggable 775 // algorithm, so that Placer is responsible for enforcing 776 // preconditions and we can experiment with other algorithms when 777 // given a choice of devices. Once we have a better idea of the 778 // types of heuristics we want to use and the information needed 779 // to perform good placement we can add an interface for this. 780 int assigned_device = -1; 781 782 // Heuristic B: If the node only operates on metadata, not data, 783 // then it is desirable to place that metadata node with its 784 // input. 785 if (IsMetadata(node)) { 786 // Make sure that the input device type is in the list of supported 787 // device types for this node. 788 const Node* input = (*node->in_edges().begin())->src(); 789 // TODO(vrv): if the input is empty, consider postponing this 790 // node's assignment to the second pass, so that we handle the 791 // case where a metadata node's input comes from a backedge 792 // of a loop. 793 if (CanAssignToDevice(input->assigned_device_name(), *devices)) { 794 assigned_device = input->assigned_device_name_index(); 795 } 796 } 797 798 // Provide the default, if necessary. 799 if (assigned_device == -1) { 800 assigned_device = graph_->InternDeviceName((*devices)[0]->name()); 801 } 802 803 AssignAndLog(assigned_device, node); 804 } 805 806 // 4. Perform a second pass assignment for those nodes explicitly 807 // skipped during the first pass. 808 for (Node* node : second_pass) { 809 std::vector<Device*>* devices; 810 Status status = colocation_graph.GetDevicesForNode(node, &devices); 811 if (!status.ok()) { 812 return AttachDef( 813 errors::InvalidArgument("Cannot assign a device for operation '", 814 node->name(), "': ", status.error_message()), 815 *node); 816 } 817 818 int assigned_device = -1; 819 820 // Heuristic A application. 821 if (IsGeneratorNode(node)) { 822 const Node* output = (*node->out_edges().begin())->dst(); 823 int output_device_name = output->assigned_device_name_index(); 824 825 const bool consumers_on_same_device = std::all_of( 826 node->out_edges().begin(), node->out_edges().end(), 827 [output_device_name](const Edge* e) { 828 return e->dst()->assigned_device_name_index() == output_device_name; 829 }); 830 831 if (consumers_on_same_device && 832 CanAssignToDevice(output->assigned_device_name(), *devices)) { 833 assigned_device = output_device_name; 834 } 835 } 836 837 // Provide the default, if necessary. 838 if (assigned_device == -1) { 839 assigned_device = graph_->InternDeviceName((*devices)[0]->name()); 840 } 841 842 AssignAndLog(assigned_device, node); 843 } 844 845 return Status::OK(); 846 } 847 848 bool Placer::CanAssignToDevice(const string& candidate_device_name, 849 const std::vector<Device*>& devices) const { 850 if (!candidate_device_name.empty()) { 851 // 'devices' lists the set of devices that the placer or the user has 852 // constrained the operation to. "candidate_device_name" must 853 // refer to a concrete Device that is in the list of 'devices'. 854 const Device* other_device = 855 devices_->FindDeviceByName(candidate_device_name); 856 if (std::find(devices.begin(), devices.end(), other_device) != 857 devices.end()) { 858 return true; 859 } 860 } 861 862 return false; 863 } 864 865 void Placer::AssignAndLog(int assigned_device, Node* node) const { 866 node->set_assigned_device_name_index(assigned_device); 867 LogDeviceAssignment(node); 868 } 869 870 void Placer::LogDeviceAssignment(const Node* node) const { 871 // Log placement if log_device_placement is set. 872 if (log_device_placement_) { 873 printf("%s: (%s): %s\n", node->name().c_str(), node->type_string().c_str(), 874 node->assigned_device_name().c_str()); 875 LOG(INFO) << node->name() << ": " 876 << "(" << node->type_string() << ")" 877 << node->assigned_device_name(); 878 } 879 } 880 881 } // namespace tensorflow 882