Home | History | Annotate | Download | only in common_runtime
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/common_runtime/placer.h"
     17 
     18 #include <memory>
     19 #include <set>
     20 #include <utility>
     21 #include <vector>
     22 
     23 #include "tensorflow/core/common_runtime/device.h"
     24 #include "tensorflow/core/framework/device_attributes.pb.h"
     25 #include "tensorflow/core/framework/graph.pb.h"
     26 #include "tensorflow/core/framework/node_def_util.h"
     27 #include "tensorflow/core/framework/op_kernel.h"
     28 #include "tensorflow/core/framework/types.h"
     29 #include "tensorflow/core/framework/types.pb.h"
     30 #include "tensorflow/core/lib/core/errors.h"
     31 #include "tensorflow/core/lib/core/stringpiece.h"
     32 
     33 namespace tensorflow {
     34 
     35 namespace {
     36 
     37 // We hoist the conversion from C-style string literal to StringPiece here,
     38 // so that we can avoid the many repeated calls to strlen().
     39 const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
     40 const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);
     41 
     42 // Returns a list of devices sorted by preferred type and then name
     43 // from 'devices' whose type is in 'supported_device_types'.  This
     44 // function searches the device types in 'supported_device_types' and
     45 // returns the subset of devices that match.
     46 std::vector<Device*> FilterSupportedDevices(
     47     const std::vector<Device*>& devices,
     48     const DeviceTypeVector& supported_device_types) {
     49   std::vector<Device*> filtered_devices;
     50   for (const DeviceType& d : supported_device_types) {
     51     for (Device* device : devices) {
     52       if (DeviceType(device->attributes().device_type()) == d) {
     53         filtered_devices.emplace_back(device);
     54       }
     55     }
     56   }
     57 
     58   auto device_sort = [](const Device* a, const Device* b) {
     59     auto a_priority = DeviceSet::DeviceTypeOrder(DeviceType(a->device_type()));
     60     auto b_priority = DeviceSet::DeviceTypeOrder(DeviceType(b->device_type()));
     61     // First sort by prioritized device type (higher is preferred) and
     62     // then by device name (lexicographically).
     63     if (a_priority != b_priority) {
     64       return a_priority > b_priority;
     65     }
     66     return StringPiece(a->name()) < StringPiece(b->name());
     67   };
     68   std::sort(filtered_devices.begin(), filtered_devices.end(), device_sort);
     69   return filtered_devices;
     70 }
     71 
     72 // This class maintains the connected components of a colocation
     73 // constraint graph, and uses this information to assign a satisfying
     74 // device placement to the nodes of the graph.
     75 //
     76 // The typical usage pattern is:
     77 //
     78 //   Graph graph = ...;
     79 //   DeviceSet device_set = ...;
     80 //   ColocationGraph colocation_graph(graph, device_set);
     81 //
     82 //   // Add all the nodes of graph to colocation_graph.
     83 //   for (Node* node : graph.nodes()) {
     84 //     TF_RETURN_IF_ERROR(colocation_graph.AddNode(*node));
     85 //   }
     86 //
     87 //   // Add one or more colocation constraint.
     88 //   Node node_1 = *graph.FindNodeId(...);
     89 //   Node node_2 = *graph.FindNodeId(...);
     90 //   TF_RETURN_IF_ERROR(colocation_graph.ColocateNodes(node_1, node_2));
     91 //
     92 //   // Assign devices based on the accumulated constraints.
     93 //   for (Node* node : graph.nodes()) {
     94 //     TF_RETURN_IF_ERROR(colocation_graph.AssignDevice(node));
     95 //   }
     96 //
     97 // The implementation uses the union-find algorithm to maintain the
     98 // connected components efficiently and incrementally as edges
     99 // (implied by ColocationGraph::ColocateNodes() invocations) are added.
    100 class ColocationGraph {
    101  public:
    102   ColocationGraph(Graph* graph, const DeviceSet* device_set,
    103                   bool allow_soft_placement)
    104       : graph_(graph),
    105         device_set_(device_set),
    106         device_types_(device_set->PrioritizedDeviceTypeList()),
    107         allow_soft_placement_(allow_soft_placement) {
    108     members_.resize(graph->num_node_ids());
    109   }
    110 
    111   // Adds each node of the Graph to this ColocationGraph as a singleton.
    112   //
    113   // NOTE: The implementation assumes that the ids of nodes passed to
    114   // this method are dense and zero-based; the memory used will be linear in
    115   // the largest node ID.
    116   // NOTE: If this method returns an error, *this is left in an undefined
    117   // state.
    118   Status ColocateAllNodes() {
    119     // This maps from a colocation group identifier to the 'root' of that
    120     // colocation group.  Note that the keys in this map are StringPiece; the
    121     // actual strings are stored under the NodeDef.  The lifetime of this map
    122     // is limited to this ColocateAllNodes() method, and no part of the
    123     // NodeDef trees are changed during the lifetime of this method, so using
    124     // StringPiece as a key is safe.
    125     //
    126     // Also, as a further optimization, we remove the "loc:@" prefix from
    127     // "class" attribute values, when they are used as keys in this table.
    128     // This allows us to use StringPiece values that refer to substrings of
    129     // 'string' values stored in NodeDef attribute lists, as well as StringPiece
    130     // values that refer to 'string' values from NodeDef::name(), without
    131     // performing any string allocations.
    132     std::unordered_map<StringPiece, const Node*, StringPieceHasher>
    133         colocation_group_root;
    134 
    135     for (Node* node : graph_->nodes()) {
    136       if (!node->IsOp()) {
    137         continue;
    138       }
    139 
    140       // When adding the node, identify whether it is part of a
    141       // colocation group.
    142 
    143       // This code is effectively the equivalent of GetNodeAttr() for a string
    144       // array, but it avoids all internal allocations (the allocation of the
    145       // backing store of the std::vector<string> as well as the copies of the
    146       // strings within it).  Instead, we combine the query of the colocation
    147       // attribute with the calls to ColocateNodeToGroup.
    148       bool found_spec = false;
    149       const AttrValue* attr_value =
    150           node->attrs().Find(kColocationAttrNameStringPiece);
    151       if (attr_value != nullptr && attr_value->has_list()) {
    152         for (const string& class_spec : attr_value->list().s()) {
    153           StringPiece spec(class_spec);
    154           if (spec.Consume(kColocationGroupPrefixStringPiece)) {
    155             found_spec = true;
    156             TF_RETURN_IF_ERROR(
    157                 ColocateNodeToGroup(&colocation_group_root, node, spec));
    158           }
    159         }
    160       }
    161 
    162       if (!found_spec) {
    163         // If the node does not specify a colocation group, then use the
    164         // name of this node as the colocation group.
    165         TF_RETURN_IF_ERROR(
    166             ColocateNodeToGroup(&colocation_group_root, node, node->name()));
    167       }
    168     }
    169 
    170     return Status::OK();
    171   }
    172 
    173   Status ColocateNodeToGroup(
    174       std::unordered_map<StringPiece, const Node*, StringPieceHasher>*
    175           colocation_group_root,
    176       Node* node, StringPiece colocation_group) {
    177     const Node*& root_node = (*colocation_group_root)[colocation_group];
    178     if (root_node == nullptr) {
    179       // This is the first node of the colocation group, so
    180       // designate this node as the 'root' of that colocation group.
    181       root_node = node;
    182     } else {
    183       // Try to colocate the node with the root.  If there is an
    184       // error, return it.
    185       Status s = ColocateNodes(*node, *root_node);
    186       if (!s.ok()) {
    187         return AttachDef(s, *node);
    188       }
    189     }
    190     return Status::OK();
    191   }
    192 
    193   // Merge the (possibly disjoint) sets containing nodes "x" and
    194   // "y". Returns OK if the all nodes in the union of these sets can
    195   // be placed on the same device type.
    196   //
    197   // NOTE: If this method returns an error, *this is left in an undefined
    198   // state.
    199   Status ColocateNodes(const Node& x, const Node& y) {
    200     int x_root = FindRoot(x.id());
    201     int y_root = FindRoot(y.id());
    202     return ColocateNodes(x, x_root, y, y_root);
    203   }
    204 
    205   // This overload of ColocateNodes() allows a caller to provide the root node
    206   // ids for the two nodes. For large graphs, this noticeably reduces the
    207   // graph load time.
    208   Status ColocateNodes(const Node& x, int x_root, const Node& y, int y_root) {
    209     if (x_root == y_root) {
    210       return Status::OK();
    211     }
    212 
    213     DCHECK_EQ(x_root, FindRoot(x.id()));
    214     DCHECK_EQ(y_root, FindRoot(y.id()));
    215 
    216     Member& x_root_member = members_[x_root];
    217     Member& y_root_member = members_[y_root];
    218 
    219     // Merge the sets by swinging the parent pointer of the smaller
    220     // tree to point to the root of the larger tree. Together with
    221     // path compression in ColocationGraph::FindRoot, this ensures
    222     // that we do not experience pathological performance on graphs
    223     // such as chains.
    224     int new_root, old_root;
    225     if (x_root_member.rank < y_root_member.rank) {
    226       // The tree rooted at x_root is shallower, so connect it to
    227       // y_root. The rank of y_root is unchanged because its new
    228       // child has strictly less rank.
    229       x_root_member.parent = y_root;
    230       new_root = y_root;
    231       old_root = x_root;
    232     } else if (x_root_member.rank > y_root_member.rank) {
    233       // The tree rooted at y_root is shallower, so connect it to
    234       // x_root. The rank of x_root is unchanged because its new
    235       // child has strictly less rank.
    236       y_root_member.parent = x_root;
    237       new_root = x_root;
    238       old_root = y_root;
    239     } else {
    240       // Both trees have the same rank, so break the tie by choosing
    241       // x_root as the new root.
    242       y_root_member.parent = x_root;
    243       // Increment the rank of the tree rooted at x_root, because it
    244       // is now strictly deeper than before.
    245       ++x_root_member.rank;
    246       new_root = x_root;
    247       old_root = y_root;
    248     }
    249 
    250     Member& new_root_member = members_[new_root];
    251     Member& old_root_member = members_[old_root];
    252 
    253     // Merge the partial device specifications, and ensure that they are
    254     // compatible. NULL options_ is treated as allowing soft placement.
    255     // TODO(mrry): Consider enriching the error message by pointing
    256     // out which nodes have the explicit partial device
    257     // specifications that caused this conflict.
    258     Status s = DeviceNameUtils::MergeDevNames(&new_root_member.device_name,
    259                                               old_root_member.device_name,
    260                                               allow_soft_placement_);
    261     if (!s.ok()) {
    262       return errors::InvalidArgument("Cannot colocate nodes '", x.name(),
    263                                      "' and '", y.name(), ": ",
    264                                      s.error_message());
    265     }
    266 
    267     // Ensure that the common root has at least one supported device
    268     // type, by computing the intersection of
    269     // new_root_member.supported_device_types and
    270     // old_root_member.supported_device_types.
    271     MergeSupportedDevices(&new_root_member.supported_device_types,
    272                           old_root_member.supported_device_types);
    273     if (new_root_member.supported_device_types.empty()) {
    274       return errors::InvalidArgument(
    275           "Cannot colocate nodes '", x.name(), "' and '", y.name(),
    276           "' because no device type supports both of those nodes and the "
    277           "other nodes colocated with them.",
    278           DebugInfo(x_root), DebugInfo(y_root));
    279     }
    280 
    281     return Status::OK();
    282   }
    283 
    284   // For the given node, subject to the constraints previously given
    285   // to this ColocationGraph, set its assigned_device_name. Returns OK
    286   // if a satisfying device can be found, otherwise an error.
    287   //
    288   // Note: This method returns a pointer to a field within members_.
    289   // The caller must not use the returned pointer after there is any possibility
    290   // that the members_[i].possible_devices field has been modified.
    291   Status GetDevicesForNode(Node* node,
    292                            std::vector<Device*>** possible_devices) {
    293     *possible_devices = nullptr;
    294     const int node_root = FindRoot(node->id());
    295     if (!members_[node_root].possible_devices.empty()) {
    296       *possible_devices = &members_[node_root].possible_devices;
    297       return Status::OK();
    298     }
    299 
    300     // We have not yet computed the possible devices for the
    301     // colocated node set containing 'node', so we do so now using the
    302     // constraints on the root node.
    303 
    304     // "devices" will contain the set of feasible placements for the
    305     // colocated node set containing 'node'.
    306     std::vector<Device*> devices;
    307     if (DeviceNameUtils::HasSomeDetails(members_[node_root].device_name)) {
    308       // The root node has a (possibly partial) device
    309       // specification, so enumerate the physical devices that
    310       // conform to it.
    311       device_set_->FindMatchingDevices(members_[node_root].device_name,
    312                                        &devices);
    313 
    314       if (!devices.empty()) {
    315         // Filter devices into those that are compatible with the root
    316         // node (and its children).
    317         devices = FilterSupportedDevices(
    318             devices, members_[node_root].supported_device_types);
    319       }
    320 
    321       // Perform soft placement if allow_soft_placement_ is set.
    322       if (devices.empty() && allow_soft_placement_) {
    323         // The soft_device_name is the same as the node's device name
    324         // without specifying the device type or ID.
    325         DeviceNameUtils::ParsedName soft_device_name =
    326             members_[node_root].device_name;
    327         soft_device_name.type.clear();
    328         soft_device_name.has_type = false;
    329         soft_device_name.has_id = false;
    330         device_set_->FindMatchingDevices(soft_device_name, &devices);
    331         if (!devices.empty()) {
    332           devices = FilterSupportedDevices(
    333               devices, members_[node_root].supported_device_types);
    334         }
    335       }
    336 
    337       if (devices.empty()) {
    338         // Return an error when a physical device that matches an explicit
    339         // device specification is not found. This ensures that we don't
    340         // assign a node to GPU when the user wanted to force it on CPU.
    341         string debug_info = DebugInfo(node_root);
    342 
    343         DeviceNameUtils::ParsedName specified_device_name;
    344         if (DeviceNameUtils::ParseFullName(node->requested_device(),
    345                                            &specified_device_name) &&
    346             specified_device_name == members_[node_root].device_name) {
    347           // The specified device and merged set device match, and
    348           // will appear in the GraphDef (for debugging), so just
    349           // print the specified device.
    350           std::vector<Device*> devices_matching_nodedef;
    351           device_set_->FindMatchingDevices(specified_device_name,
    352                                            &devices_matching_nodedef);
    353           if (devices_matching_nodedef.empty()) {
    354             // Sometimes it is almost impossible to understand the problem
    355             // without a list of available devices.
    356             std::vector<string> device_names;
    357             for (const Device* device : device_set_->devices()) {
    358               device_names.push_back(device->name());
    359             }
    360             std::sort(device_names.begin(), device_names.end());
    361 
    362             return errors::InvalidArgument(
    363                 "Operation was explicitly assigned to ",
    364                 node->requested_device(), " but available devices are [ ",
    365                 str_util::Join(device_names, ", "), " ]. Make sure ",
    366                 "the device specification refers to a valid device.");
    367           } else if (specified_device_name.has_type) {
    368             return errors::InvalidArgument(
    369                 "Could not satisfy explicit device specification '",
    370                 node->requested_device(), "' because no supported kernel for ",
    371                 specified_device_name.type, " devices is available.",
    372                 debug_info, "\nRegistered kernels:\n",
    373                 KernelsRegisteredForOp(node->type_string()));
    374           } else {
    375             return errors::InvalidArgument(
    376                 "Could not satisfy explicit device specification '",
    377                 node->requested_device(), debug_info);
    378           }
    379         } else {
    380           // The specified device may be a valid device but the
    381           // merged set device is different, so print both.
    382           return errors::InvalidArgument(
    383               "Could not satisfy explicit device specification '",
    384               node->requested_device(),
    385               "' because the node was colocated with a group of nodes that "
    386               "required incompatible device '",
    387               DeviceNameUtils::ParsedNameToString(
    388                   members_[node_root].device_name),
    389               "'", debug_info);
    390         }
    391       }
    392     } else {
    393       // The device is completely unspecified, so enumerate the devices that
    394       // support all of the nodes in the set.
    395       if (device_set_->devices().empty()) {
    396         return errors::Internal("No devices are registered");
    397       }
    398       devices = FilterSupportedDevices(
    399           device_set_->devices(), members_[node_root].supported_device_types);
    400 
    401       if (devices.empty()) {
    402         return errors::InvalidArgument(
    403             "Node had no OpKernel registered to support this operation: ",
    404             "Operation was ", node->type_string(), " and inputs were ",
    405             DataTypeVectorString(node->input_types()), DebugInfo(node_root));
    406       }
    407     }
    408 
    409     // Cache the result of the possible devices for this node group.
    410     members_[node_root].possible_devices = std::move(devices);
    411     *possible_devices = &members_[node_root].possible_devices;
    412     return Status::OK();
    413   }
    414 
    415   Status InitializeMembers() {
    416     for (Node* node : graph_->nodes()) {
    417       if (!node->IsOp()) {
    418         continue;
    419       }
    420       Status status = InitializeMember(*node, &members_[node->id()]);
    421       if (!status.ok()) {
    422         return AttachDef(status, *node);
    423       }
    424     }
    425     return Status::OK();
    426   }
    427 
    428   // Represents a node in the disjoint node set forest, and the
    429   // accumulated constraints on the device used by that node.
    430   struct Member {
    431     Member() = default;
    432     // The id of the node that is the parent of this one, or its own
    433     // id if it is a root. parent <= 0 indicates that this member is invalid.
    434     int parent = -1;
    435 
    436     // A proxy for the depth of the tree that is used to prefer
    437     // connecting smaller trees to larger trees when merging disjoint
    438     // sets.
    439     int rank = 0;
    440 
    441     // The intersection of all device types supported by this node,
    442     // and those of all of its children, in priority order
    443     // of the preferred device.
    444     DeviceTypeVector supported_device_types;
    445 
    446     // The merged form of the device requested for this node, with
    447     // those of all of its children.
    448     DeviceNameUtils::ParsedName device_name;
    449 
    450     // If this node is a root, stores a list of Devices to which this node
    451     // and all of its children have been assigned, or nullptr if this
    452     // has not yet been computed.
    453     std::vector<Device*> possible_devices;
    454   };
    455 
    456   // Returns debugging info for the node referred to by 'node_root'.
    457   string DebugInfo(const int node_root) {
    458     string text(
    459         "\nColocation Debug Info:\n"
    460         "Colocation group had the following types and devices: ");
    461 
    462     // If this node is part of a colocation group, then we want to
    463     // collect the mapping of ops to supported devices, so that
    464     // the user can see why an unsatisfiable placement occurred.
    465 
    466     std::unordered_map<string, string> type_to_devices;
    467     int num_nodes_found = 0;
    468 
    469     for (const Node* node : graph_->nodes()) {
    470       if (!node->IsOp()) {
    471         continue;
    472       }
    473       int id = node->id();
    474       if (FindRoot(id) != node_root) {
    475         continue;
    476       }
    477       ++num_nodes_found;
    478       const string& op_type = node->type_string();
    479       string devices_registered;
    480       for (const auto& device_type : members_[id].supported_device_types) {
    481         strings::StrAppend(&devices_registered, DeviceTypeString(device_type),
    482                            " ");
    483       }
    484 
    485       type_to_devices[op_type] = std::move(devices_registered);
    486     }
    487 
    488     for (const auto& td : type_to_devices) {
    489       strings::StrAppend(&text, "\n", td.first, ": ", td.second);
    490     }
    491 
    492     if (num_nodes_found <= 1) {
    493       text.clear();
    494     }
    495     return text;
    496   }
    497 
    498   Status InitializeMember(const Node& node, Member* member) {
    499     const int id = node.id();
    500     DCHECK_GE(id, 0);
    501     member->parent = id;
    502     TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
    503         device_types_, node.def(), &member->supported_device_types));
    504 
    505     if (node.has_assigned_device_name()) {
    506       // This node has already been assigned to a device, so we
    507       // respect this placement, after sanity-checking it.  The
    508       // device_name and supported_device_types for this node reflect
    509       // the assigned device, so any nodes colocated with this node
    510       // will be assigned to the same device (assuming this is
    511       // possible).
    512       // NOTE: Since any assignment must have been performed by
    513       // the TensorFlow runtime, we consider errors in this branch to
    514       // be INTERNAL.
    515       const string& assigned_device_name = node.assigned_device_name();
    516       if (!DeviceNameUtils::ParseFullName(assigned_device_name,
    517                                           &member->device_name)) {
    518         return errors::Internal("Malformed assigned device '",
    519                                 assigned_device_name, "'");
    520       }
    521       const Device* assigned_device =
    522           device_set_->FindDeviceByName(assigned_device_name);
    523       if (assigned_device == nullptr) {
    524         return errors::Internal("Assigned device '", assigned_device_name,
    525                                 "' does not match any device");
    526       }
    527 
    528       for (const DeviceType& d : member->supported_device_types) {
    529         if (DeviceType(assigned_device->attributes().device_type()) == d) {
    530           return Status::OK();
    531         }
    532       }
    533 
    534       return errors::Internal("Assigned device '", assigned_device_name,
    535                               "' does not have registered OpKernel support "
    536                               "for ",
    537                               node.type_string());
    538     } else {
    539       // This node has not yet been assigned to a device, so we
    540       // calculate any constraints due to the set of registered
    541       // kernels and any (partial) user-provided device specification
    542       // in the NodeDef.
    543 
    544       // If no kernels are registered for this op type, fail with an error.
    545       if (member->supported_device_types.empty()) {
    546         std::set<string> registered_device_types;
    547         for (Device* d : device_set_->devices()) {
    548           registered_device_types.insert(d->device_type());
    549         }
    550         return errors::InvalidArgument(
    551             "No OpKernel was registered to support Op '", node.type_string(),
    552             "' with these attrs.  Registered devices: [",
    553             str_util::Join(registered_device_types, ","),
    554             "], Registered kernels:\n",
    555             KernelsRegisteredForOp(node.type_string()));
    556       }
    557 
    558       // If the NodeDef contains a device, then we interpret it as a
    559       // (partial) device specification.
    560       if (!node.requested_device().empty()) {
    561         // The user has specified a device in the NodeDef, try to find a
    562         // valid device matching their specification in the set of
    563         // devices.
    564         // NOTE: The full name may specify a device that is not in
    565         // n.supported_device_types(), but we check that in AssignDevice().
    566         if (!DeviceNameUtils::ParseFullName(node.requested_device(),
    567                                             &member->device_name)) {
    568           return errors::InvalidArgument("Malformed device specification '",
    569                                          node.requested_device(), "'");
    570         }
    571       }
    572     }
    573     return Status::OK();
    574   }
    575 
    576   // Updates target to contain the intersection of the device types in
    577   // "target" and "other".
    578   static void MergeSupportedDevices(DeviceTypeVector* target,
    579                                     const DeviceTypeVector& other) {
    580     DeviceTypeVector temp = *target;
    581     target->clear();
    582 
    583     // Iterate in priority order.
    584     for (const DeviceType& device_type : temp) {
    585       bool found = false;
    586       for (const DeviceType& other_device_type : other) {
    587         if (device_type == other_device_type) {
    588           found = true;
    589           break;
    590         }
    591       }
    592       if (found) {
    593         target->push_back(device_type);
    594       }
    595     }
    596   }
    597 
    598   // Returns the root node of the disjoint tree to which the node with the
    599   // given id is connected.
    600   int FindRoot(int node_id) {
    601     Member& member = members_[node_id];
    602 
    603     int parent = member.parent;
    604     DCHECK_GE(parent, 0);
    605 
    606     if (parent != node_id) {
    607       // NOTE: Compress paths from node_id to its root, so that future
    608       // calls to FindRoot and ColocateNodes are more efficient.
    609       int root = FindRoot(parent);
    610       if (parent != root) {
    611         parent = root;
    612         member.parent = root;
    613       }
    614     }
    615 
    616     DCHECK_GE(parent, 0);
    617     return parent;
    618   }
    619 
    620   Graph* const graph_;  // Not owned.
    621   std::vector<Member> members_;
    622   const DeviceSet* device_set_;  // Not owned.
    623   const std::vector<DeviceType> device_types_;
    624   const bool allow_soft_placement_;
    625 };
    626 
    627 // Returns true if the node has no inputs and produces outputs
    628 // that are consumed by a single node.
    629 //
    630 // TODO(vrv): Currently this handles only nodes with one output, but
    631 // this could be extended to handle the case where a node has many
    632 // outputs that are connected to nodes in the same colocation group.
    633 bool IsGeneratorNode(const Node* node) {
    634   return node->num_inputs() == 0 && node->num_outputs() == 1 &&
    635          !IsRefType(node->output_type(0));
    636 }
    637 
    638 }  // namespace
    639 
    640 Placer::Placer(Graph* graph, const DeviceSet* devices,
    641                const SessionOptions* options)
    642     : graph_(graph),
    643       devices_(devices),
    644       options_(options),
    645       log_device_placement_(options != nullptr &&
    646                             options->config.log_device_placement()) {}
    647 
    648 Placer::Placer(Graph* graph, const DeviceSet* devices)
    649     : Placer(graph, devices, nullptr) {}
    650 
    651 Placer::~Placer() {}
    652 
    653 Status Placer::Run() {
    654   if (devices_->devices().empty()) {
    655     return errors::FailedPrecondition("No devices are registered");
    656   }
    657 
    658   ColocationGraph colocation_graph(
    659       graph_, devices_,
    660       options_ == nullptr || options_->config.allow_soft_placement());
    661 
    662   TF_RETURN_IF_ERROR(colocation_graph.InitializeMembers());
    663 
    664   // 1. First add all of the nodes. Note that steps (1) and (2)
    665   // requires two passes over the nodes because the graph (and hence
    666   // the constraints) may not be acyclic.
    667   TF_RETURN_IF_ERROR(colocation_graph.ColocateAllNodes());
    668 
    669   // 2. Enumerate the constraint edges, and use them to update the disjoint
    670   // node set.
    671 
    672   // If `node` has an input edge with reference type, add an
    673   // edge from the source of that edge to `node`.
    674   for (const Edge* edge : graph_->edges()) {
    675     if (edge->IsControlEdge()) {
    676       continue;
    677     }
    678     Node* src = edge->src();
    679     Node* dst = edge->dst();
    680     DataType input_type = dst->input_type(edge->dst_input());
    681     if (input_type == DT_RESOURCE || IsRefType(input_type)) {
    682       int src_root_id = colocation_graph.FindRoot(src->id());
    683       int dst_root_id = colocation_graph.FindRoot(dst->id());
    684       auto& src_root = colocation_graph.members_[src_root_id];
    685       auto& dst_root = colocation_graph.members_[dst_root_id];
    686       // If both the source node and this node have partially
    687       // specified a device, then 'node's device should be
    688       // cleared: the reference edge forces 'node' to be on the
    689       // same device as the source node.
    690       const auto& source_parsed_name = src_root.device_name;
    691       const auto& dest_parsed_name = dst_root.device_name;
    692       if (DeviceNameUtils::HasSomeDetails(source_parsed_name) &&
    693           DeviceNameUtils::HasSomeDetails(dest_parsed_name)) {
    694         // Ignore a specified device for 'dst' if the two names were
    695         // incompatible.
    696         if (!DeviceNameUtils::AreCompatibleDevNames(source_parsed_name,
    697                                                     dest_parsed_name)) {
    698           if (log_device_placement_) {
    699             LOG(INFO) << "Ignoring device specification "
    700                       << DeviceNameUtils::ParsedNameToString(dest_parsed_name)
    701                       << " for node '" << dst->name()
    702                       << "' because the input edge from '" << src->name()
    703                       << "' is a reference connection and already has a device "
    704                          "field set to "
    705                       << DeviceNameUtils::ParsedNameToString(
    706                              source_parsed_name);
    707           }
    708 
    709           // Make 'dst' colocated with the source
    710           dst_root.device_name = source_parsed_name;
    711         } else {
    712           bool source_subset_of_dest = DeviceNameUtils::IsSpecification(
    713               source_parsed_name, dest_parsed_name);
    714           bool dest_subset_of_source = DeviceNameUtils::IsSpecification(
    715               dest_parsed_name, source_parsed_name);
    716 
    717           if (source_subset_of_dest && !dest_subset_of_source) {
    718             src_root.device_name = dest_parsed_name;
    719           } else {
    720             dst_root.device_name = source_parsed_name;
    721           }
    722         }
    723       }
    724 
    725       Status status =
    726           colocation_graph.ColocateNodes(*src, src_root_id, *dst, dst_root_id);
    727       if (!status.ok()) {
    728         return AttachDef(
    729             errors::InvalidArgument("Nodes were connected by a "
    730                                     "reference connection (requiring them to "
    731                                     "be on the same device), but the two nodes "
    732                                     "were assigned two different devices: ",
    733                                     status.error_message()),
    734             *dst);
    735       }
    736     }
    737   }
    738 
    739   // 3. For each node, assign a device based on the constraints in the
    740   // disjoint node set.
    741   std::vector<Node*> second_pass;
    742   for (Node* node : graph_->op_nodes()) {
    743     // The graph may have come pre-populated by the framework with assigned
    744     // devices (e.g., for stateful placements), so the placer should not try to
    745     // place nodes that are already placed.
    746     if (node->has_assigned_device_name()) {
    747       LogDeviceAssignment(node);
    748       continue;
    749     }
    750 
    751     // Heuristic A: prefer to place "generators" with their only
    752     // consumers.
    753     //
    754     // If this is a node with no inputs and one output, we save
    755     // this for a second pass, so that the consumer's placement
    756     // is chosen.
    757     if (IsGeneratorNode(node)) {
    758       second_pass.push_back(node);
    759       continue;
    760     }
    761 
    762     std::vector<Device*>* devices;
    763     Status status = colocation_graph.GetDevicesForNode(node, &devices);
    764     if (!status.ok()) {
    765       return AttachDef(
    766           errors::InvalidArgument("Cannot assign a device for operation '",
    767                                   node->name(), "': ", status.error_message()),
    768           *node);
    769     }
    770 
    771     // Returns the first device in sorted devices list so we will always
    772     // choose the same device.
    773     //
    774     // TODO(vrv): Factor this assignment out into a pluggable
    775     // algorithm, so that Placer is responsible for enforcing
    776     // preconditions and we can experiment with other algorithms when
    777     // given a choice of devices. Once we have a better idea of the
    778     // types of heuristics we want to use and the information needed
    779     // to perform good placement we can add an interface for this.
    780     int assigned_device = -1;
    781 
    782     // Heuristic B: If the node only operates on metadata, not data,
    783     // then it is desirable to place that metadata node with its
    784     // input.
    785     if (IsMetadata(node)) {
    786       // Make sure that the input device type is in the list of supported
    787       // device types for this node.
    788       const Node* input = (*node->in_edges().begin())->src();
    789       // TODO(vrv): if the input is empty, consider postponing this
    790       // node's assignment to the second pass, so that we handle the
    791       // case where a metadata node's input comes from a backedge
    792       // of a loop.
    793       if (CanAssignToDevice(input->assigned_device_name(), *devices)) {
    794         assigned_device = input->assigned_device_name_index();
    795       }
    796     }
    797 
    798     // Provide the default, if necessary.
    799     if (assigned_device == -1) {
    800       assigned_device = graph_->InternDeviceName((*devices)[0]->name());
    801     }
    802 
    803     AssignAndLog(assigned_device, node);
    804   }
    805 
    806   // 4. Perform a second pass assignment for those nodes explicitly
    807   // skipped during the first pass.
    808   for (Node* node : second_pass) {
    809     std::vector<Device*>* devices;
    810     Status status = colocation_graph.GetDevicesForNode(node, &devices);
    811     if (!status.ok()) {
    812       return AttachDef(
    813           errors::InvalidArgument("Cannot assign a device for operation '",
    814                                   node->name(), "': ", status.error_message()),
    815           *node);
    816     }
    817 
    818     int assigned_device = -1;
    819 
    820     // Heuristic A application.
    821     if (IsGeneratorNode(node)) {
    822       const Node* output = (*node->out_edges().begin())->dst();
    823       int output_device_name = output->assigned_device_name_index();
    824 
    825       const bool consumers_on_same_device = std::all_of(
    826           node->out_edges().begin(), node->out_edges().end(),
    827           [output_device_name](const Edge* e) {
    828             return e->dst()->assigned_device_name_index() == output_device_name;
    829           });
    830 
    831       if (consumers_on_same_device &&
    832           CanAssignToDevice(output->assigned_device_name(), *devices)) {
    833         assigned_device = output_device_name;
    834       }
    835     }
    836 
    837     // Provide the default, if necessary.
    838     if (assigned_device == -1) {
    839       assigned_device = graph_->InternDeviceName((*devices)[0]->name());
    840     }
    841 
    842     AssignAndLog(assigned_device, node);
    843   }
    844 
    845   return Status::OK();
    846 }
    847 
    848 bool Placer::CanAssignToDevice(const string& candidate_device_name,
    849                                const std::vector<Device*>& devices) const {
    850   if (!candidate_device_name.empty()) {
    851     // 'devices' lists the set of devices that the placer or the user has
    852     // constrained the operation to.  "candidate_device_name" must
    853     // refer to a concrete Device that is in the list of 'devices'.
    854     const Device* other_device =
    855         devices_->FindDeviceByName(candidate_device_name);
    856     if (std::find(devices.begin(), devices.end(), other_device) !=
    857         devices.end()) {
    858       return true;
    859     }
    860   }
    861 
    862   return false;
    863 }
    864 
    865 void Placer::AssignAndLog(int assigned_device, Node* node) const {
    866   node->set_assigned_device_name_index(assigned_device);
    867   LogDeviceAssignment(node);
    868 }
    869 
    870 void Placer::LogDeviceAssignment(const Node* node) const {
    871   // Log placement if log_device_placement is set.
    872   if (log_device_placement_) {
    873     printf("%s: (%s): %s\n", node->name().c_str(), node->type_string().c_str(),
    874            node->assigned_device_name().c_str());
    875     LOG(INFO) << node->name() << ": "
    876               << "(" << node->type_string() << ")"
    877               << node->assigned_device_name();
    878   }
    879 }
    880 
    881 }  // namespace tensorflow
    882