Home | History | Annotate | Download | only in graph
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/graph/graph_partition.h"
     17 
     18 #include <deque>
     19 #include <queue>
     20 #include <unordered_map>
     21 #include <unordered_set>
     22 #include <utility>
     23 #include <vector>
     24 
     25 #include "absl/container/flat_hash_map.h"
     26 #include "tensorflow/core/framework/function.h"
     27 #include "tensorflow/core/framework/memory_types.h"
     28 #include "tensorflow/core/framework/node_def_builder.h"
     29 #include "tensorflow/core/framework/tensor.pb.h"
     30 #include "tensorflow/core/framework/types.h"
     31 #include "tensorflow/core/framework/versions.pb.h"
     32 #include "tensorflow/core/graph/algorithm.h"
     33 #include "tensorflow/core/graph/control_flow.h"
     34 #include "tensorflow/core/graph/costmodel.h"
     35 #include "tensorflow/core/graph/graph_def_builder.h"
     36 #include "tensorflow/core/graph/node_builder.h"
     37 #include "tensorflow/core/graph/tensor_id.h"
     38 #include "tensorflow/core/lib/core/errors.h"
     39 #include "tensorflow/core/lib/hash/hash.h"
     40 #include "tensorflow/core/lib/strings/str_util.h"
     41 #include "tensorflow/core/platform/logging.h"
     42 #include "tensorflow/core/util/device_name_utils.h"
     43 
     44 namespace tensorflow {
     45 
     46 namespace {
     47 
     48 inline bool IsMerge(const NodeDef& node_def) {
     49   return node_def.op() == "Merge" || node_def.op() == "RefMerge";
     50 }
     51 
     52 inline bool IsNextIteration(const NodeDef& node_def) {
     53   return node_def.op() == "NextIteration" ||
     54          node_def.op() == "RefNextIteration";
     55 }
     56 
     57 struct DupRecvKey {
     58   int src_node_id;           // Edge's src node id
     59   int src_output_slot;       // Edge's src node output slot
     60   GraphDef* dst_graph;       // Edge's dst node is in this subgraph
     61   bool recv_output_on_host;  // The output of recv is on host
     62 
     63   template <typename H>
     64   friend H AbslHashValue(H h, const DupRecvKey& c) {
     65     return H::combine(std::move(h), c.src_node_id, c.src_output_slot,
     66                       reinterpret_cast<std::uintptr_t>(c.dst_graph),
     67                       c.recv_output_on_host);
     68   }
     69 
     70   friend bool operator==(const DupRecvKey& x, const DupRecvKey& y) {
     71     return (x.src_node_id == y.src_node_id) &&
     72            (x.src_output_slot == y.src_output_slot) &&
     73            (x.dst_graph == y.dst_graph) &&
     74            (x.recv_output_on_host == y.recv_output_on_host);
     75   }
     76 };
     77 
     78 // struct used to store the recvs, so that start times can be properly updated
     79 struct RecvInfo {
     80   NodeDef* recv;
     81   NodeDef* real_recv;
     82   int64 start_time;
     83 };
     84 
     85 typedef absl::flat_hash_map<DupRecvKey, RecvInfo> DupRecvTable;
     86 
     87 // A map used to store memory types for the inputs/outputs of every node.
     88 // The key is a pair of ints consisting of a node id and input/output index.
     89 // TODO(power): migrate back to std::pair when absl::Hash is fixed for MSVC.
     90 struct NodePort {
     91   int node_id;
     92   int index;
     93 
     94   friend bool operator==(const NodePort& x, const NodePort& y) {
     95     return x.node_id == y.node_id && x.index == y.index;
     96   }
     97 
     98   template <typename H>
     99   friend H AbslHashValue(H h, const NodePort& c) {
    100     return H::combine(std::move(h), c.node_id, c.index);
    101   }
    102 };
    103 
    104 typedef absl::flat_hash_map<NodePort, MemoryType> MemoryTypeMap;
    105 
    106 // We collect the following information about the graph before performing
    107 // graph partitioning.
    108 struct GraphInfo {
    109   std::vector<DeviceType> device_types;
    110   MemoryTypeMap input_types;
    111   MemoryTypeMap output_types;
    112   std::vector<ControlFlowInfo> cf_info;
    113 };
    114 
    115 DataType EdgeType(const Edge* e) {
    116   if (e->IsControlEdge()) {
    117     return DT_FLOAT;
    118   } else {
    119     return e->dst()->input_type(e->dst_input());
    120   }
    121 }
    122 
    123 // Return true iff we need to add the same device send/recv for 'edge'.
    124 bool NeedSameDeviceSendRecv(const Edge* edge, const GraphInfo& info) {
    125   if (edge->IsControlEdge()) {
    126     return false;
    127   }
    128 
    129   const Node* src = edge->src();
    130   const Node* dst = edge->dst();
    131   if (src->assigned_device_name() == dst->assigned_device_name()) {
    132     int src_port = edge->src_output();
    133     int dst_port = edge->dst_input();
    134     if (info.device_types[src->id()] != DEVICE_CPU) {
    135       auto src_it = info.output_types.find({src->id(), src_port});
    136       DCHECK(src_it != info.output_types.end());
    137       auto dst_it = info.input_types.find({dst->id(), dst_port});
    138       DCHECK(dst_it != info.input_types.end());
    139       return src_it->second != dst_it->second;
    140     }
    141   }
    142   return false;
    143 }
    144 
    145 // Return true iff (dst, dst_input) is specified on host memory.
    146 bool IsDstInputOnHost(const Edge* edge, const GraphInfo& info) {
    147   const Node* dst = edge->dst();
    148   int dst_port = edge->dst_input();
    149   if (info.device_types[dst->id()] != DEVICE_CPU) {
    150     if (edge->IsControlEdge()) return false;
    151     auto dst_it = info.input_types.find({dst->id(), dst_port});
    152     DCHECK(dst_it != info.input_types.end());
    153     return dst_it->second == HOST_MEMORY;
    154   }
    155   return true;
    156 }
    157 
    158 // Add an input to dst that comes from the "src_slot" output of the
    159 // node named by "src_name".
    160 void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
    161   if (src_slot == Graph::kControlSlot) {
    162     dst->add_input(strings::StrCat("^", src_name));
    163   } else if (src_slot == 0) {
    164     dst->add_input(src_name.data(), src_name.size());
    165   } else {
    166     dst->add_input(strings::StrCat(src_name, ":", src_slot));
    167   }
    168 }
    169 
    170 // Add a control edge from each input to each recv.
    171 void AddReadControl(const std::vector<NodeDef*>& recvs,
    172                     const std::vector<string>& inputs) {
    173   for (NodeDef* recv : recvs) {
    174     for (const string& input : inputs) {
    175       recv->add_input(strings::StrCat("^", input));
    176     }
    177   }
    178 }
    179 
    180 void SetSendRecvAttrs(const PartitionOptions& opts, const Edge* edge,
    181                       NodeDefBuilder* builder) {
    182   builder->Attr("tensor_name",
    183                 strings::StrCat("edge_", edge->id(), "_", edge->src()->name()));
    184   builder->Attr("send_device", edge->src()->assigned_device_name());
    185   builder->Attr("send_device_incarnation",
    186                 static_cast<int64>(
    187                     opts.get_incarnation(edge->src()->assigned_device_name())));
    188   builder->Attr("recv_device", edge->dst()->assigned_device_name());
    189   builder->Attr("client_terminated", false);
    190 }
    191 
    192 NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
    193                  GraphDef* gdef, const Edge* edge,
    194                  NodeDefBuilder::NodeOut send_from, int64 start_time,
    195                  Status* status) {
    196   const DataType dtype = send_from.data_type;
    197   const DataType cast_dtype = opts.should_cast ? opts.should_cast(edge) : dtype;
    198   const Node* src = edge->src();
    199   const int src_port = edge->src_output();
    200 
    201   // host_memory = true iff we need to use HostSend/HostCast.
    202   bool host_memory = false;
    203   if (!edge->IsControlEdge()) {
    204     auto src_it = g_info.output_types.find({src->id(), src_port});
    205     DCHECK(src_it != g_info.output_types.end());
    206     host_memory = (src_it->second == HOST_MEMORY);
    207   }
    208 
    209   // Add a cast node that casts dtype to cast_dtype.
    210   // NOTE(yuanbyu): Only cast for cross-device send/recv.
    211   if (dtype != cast_dtype && !NeedSameDeviceSendRecv(edge, g_info)) {
    212     const string cast_op = (host_memory) ? "_HostCast" : "Cast";
    213     NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
    214                                 NodeDebugInfo(*src));
    215     cast_builder.Device(src->assigned_device_name()).Input(send_from);
    216     if (opts.scheduling_for_recvs) {
    217       cast_builder.Attr("_start_time", start_time);
    218     }
    219     cast_builder.Attr("DstT", cast_dtype);
    220 
    221     if (cast_dtype == DT_BFLOAT16) {
    222       // the below attribute specifies that the cast to bfloat16 should use
    223       // truncation. This is needed to retain legacy behavior when we change
    224       // the default bfloat16 casts to use rounding instead of truncation
    225       cast_builder.Attr("Truncate", true);
    226     }
    227 
    228     NodeDef* cast = gdef->add_node();
    229     *status = cast_builder.Finalize(cast);
    230     if (!status->ok()) return nullptr;
    231 
    232     // Connect the Send op to the cast.
    233     send_from.Reset(cast->name(), 0, cast_dtype);
    234   }
    235 
    236   // Add the send node.
    237   const string send_op = (host_memory) ? "_HostSend" : "_Send";
    238   NodeDefBuilder send_builder(opts.new_name(src->name()), send_op,
    239                               NodeDebugInfo(*src));
    240   SetSendRecvAttrs(opts, edge, &send_builder);
    241   send_builder.Device(src->assigned_device_name()).Input(send_from);
    242   if (opts.scheduling_for_recvs) {
    243     send_builder.Attr("_start_time", start_time);
    244   }
    245   NodeDef* send = gdef->add_node();
    246   *status = send_builder.Finalize(send);
    247   return send;
    248 }
    249 
    250 NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info,
    251                  GraphDef* gdef, const Edge* edge, NodeDef** real_recv,
    252                  Status* status) {
    253   const DataType dtype = EdgeType(edge);
    254   const Node* src = edge->src();
    255   const Node* dst = edge->dst();
    256   const int dst_port = edge->dst_input();
    257   DataType cast_dtype = dtype;
    258 
    259   // NOTE(yuanbyu): Only cast for cross-device send/recv.
    260   if (opts.should_cast && !NeedSameDeviceSendRecv(edge, g_info)) {
    261     cast_dtype = opts.should_cast(edge);
    262   }
    263 
    264   // host_memory = true iff we need to use HostRecv/HostCast.
    265   bool host_memory = false;
    266   if (!edge->IsControlEdge()) {
    267     auto dst_it = g_info.input_types.find({dst->id(), dst_port});
    268     DCHECK(dst_it != g_info.input_types.end());
    269     host_memory = (dst_it->second == HOST_MEMORY);
    270   }
    271 
    272   // Add the recv node.
    273   const string recv_op = (host_memory) ? "_HostRecv" : "_Recv";
    274   NodeDefBuilder recv_builder(opts.new_name(src->name()), recv_op,
    275                               NodeDebugInfo(*src));
    276   SetSendRecvAttrs(opts, edge, &recv_builder);
    277   recv_builder.Device(dst->assigned_device_name())
    278       .Attr("tensor_type", cast_dtype);
    279   NodeDef* recv = gdef->add_node();
    280   *status = recv_builder.Finalize(recv);
    281   if (!status->ok()) return nullptr;
    282   *real_recv = recv;
    283 
    284   // Add the cast node (from cast_dtype to dtype) or an Identity node.
    285   if (dtype != cast_dtype) {
    286     const string cast_op = (host_memory) ? "_HostCast" : "Cast";
    287     NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
    288                                 NodeDebugInfo(*src));
    289     cast_builder.Attr("DstT", dtype);
    290     cast_builder.Device(dst->assigned_device_name())
    291         .Input(recv->name(), 0, cast_dtype);
    292     NodeDef* cast = gdef->add_node();
    293     *status = cast_builder.Finalize(cast);
    294     if (!status->ok()) return nullptr;
    295     return cast;
    296   } else if (edge->IsControlEdge()) {
    297     // An Identity is only needed for control edges.
    298     NodeDefBuilder id_builder(opts.new_name(src->name()), "Identity",
    299                               NodeDebugInfo(*src));
    300     id_builder.Device(dst->assigned_device_name())
    301         .Input(recv->name(), 0, cast_dtype);
    302     NodeDef* id = gdef->add_node();
    303     *status = id_builder.Finalize(id);
    304     if (!status->ok()) return nullptr;
    305     return id;
    306   } else {
    307     return recv;
    308   }
    309 }
    310 
    311 NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef,
    312                        const Edge* edge, Status* status) {
    313   const Node* src = edge->src();
    314   Tensor tensor(DT_FLOAT, TensorShape({0}));
    315   NodeDef* result = gdef->add_node();
    316   *status = NodeDefBuilder(opts.new_name(src->name()), "Const")
    317                 .Device(src->assigned_device_name())
    318                 .Attr("dtype", DT_FLOAT)
    319                 .Attr("value", tensor)
    320                 .Finalize(result);
    321   return result;
    322 }
    323 
    324 // A dummy node for scheduling.
    325 NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef,
    326                            const string& assigned_device_name, int64 epoch,
    327                            int64 starttime, Status* status) {
    328   NodeDef* result = gdef->add_node();
    329   *status = NodeDefBuilder(opts.new_name(strings::StrCat("synch_", epoch)),
    330                            "ControlTrigger")
    331                 .Device(assigned_device_name)
    332                 .Attr("_start_time", starttime)
    333                 .Finalize(result);
    334   return result;
    335 }
    336 
    337 // Optimize colocation for control flow nodes. For cond, we want the
    338 // switch nodes to colocate with its data input. This is particularly
    339 // needed for conditional reading of a remote variable. It may also
    340 // reduce the number of devices involved in a loop.
    341 // TODO(yuanbyu): In this case, we don't respect the requested device in
    342 // the GraphDef for these nodes. Ideally, the placer would enforce the
    343 // colocation to render this unnecessary.
    344 void OptimizeControlFlowColocation(Graph* graph) {
    345   auto visit = [](Node* node) {
    346     if (IsSwitch(node)) {
    347       for (const Edge* in_edge : node->in_edges()) {
    348         if (in_edge->dst_input() == 0) {
    349           // Colocate with the data input.
    350           node->set_assigned_device_name(
    351               in_edge->src()->assigned_device_name());
    352           return;
    353         }
    354       }
    355     } else if (IsExit(node)) {
    356       for (const Edge* in_edge : node->in_edges()) {
    357         if (!in_edge->IsControlEdge()) {
    358           // Colocate with upstream node.
    359           node->set_assigned_device_name(
    360               in_edge->src()->assigned_device_name());
    361           return;
    362         }
    363       }
    364     } else {
    365       if ((IsEnter(node) && !IsRefType(node->input_type(0))) ||
    366           IsNextIteration(node)) {
    367         const Edge* data_edge = nullptr;
    368         for (const Edge* out_edge : node->out_edges()) {
    369           if (!out_edge->IsControlEdge()) {
    370             data_edge = out_edge;
    371             break;
    372           }
    373         }
    374         // Colocate with the first downstream data node.
    375         if (data_edge) {
    376           node->set_assigned_device_name(
    377               data_edge->dst()->assigned_device_name());
    378         }
    379       }
    380     }
    381   };
    382   DFS(*graph, visit, {});
    383 }
    384 
    385 string ControlLoopName(const string& name) {
    386   return strings::StrCat("_cloop", name);
    387 }
    388 
    389 bool IsControlLoop(const Node* node) {
    390   const string& name = node->name();
    391   return str_util::StartsWith(name, "_cloop");
    392 }
    393 
    394 // An enter node for control flow.
    395 Node* AddControlEnter(Graph* g, const string& node_name,
    396                       const string& device_name, const string& frame_name,
    397                       const int parallel_iterations, Status* status) {
    398   NodeBuilder node_builder(node_name, "Enter", g->op_registry());
    399   node_builder.Input({"dummy", 0, DT_FLOAT});
    400   node_builder.Attr("frame_name", frame_name);
    401   node_builder.Attr("parallel_iterations", parallel_iterations);
    402   Node* res_node;
    403   *status = node_builder.Finalize(g, &res_node);
    404   if (!status->ok()) return nullptr;
    405   res_node->set_assigned_device_name(device_name);
    406   return res_node;
    407 }
    408 
    409 // A merge node for control flow.
    410 Node* AddControlMerge(const string& in_name1, const string& in_name2, Graph* g,
    411                       const string& node_name, const string& device_name,
    412                       Status* status) {
    413   NodeBuilder node_builder(node_name, "Merge", g->op_registry());
    414   node_builder.Input({{in_name1, 0, DT_FLOAT}, {in_name2, 0, DT_FLOAT}});
    415   Node* res_node;
    416   *status = node_builder.Finalize(g, &res_node);
    417   if (!status->ok()) return nullptr;
    418   res_node->set_assigned_device_name(device_name);
    419   return res_node;
    420 }
    421 
    422 // A switch node for control flow.
    423 Node* AddControlSwitch(NodeBuilder::NodeOut input1, NodeBuilder::NodeOut input2,
    424                        const string& device_name,
    425                        const GraphDefBuilder::Options& bopts) {
    426   Node* res_node =
    427       ops::BinaryOp("Switch", std::move(input1), std::move(input2), bopts);
    428   if (bopts.HaveError()) return nullptr;
    429   res_node->set_assigned_device_name(device_name);
    430   return res_node;
    431 }
    432 
    433 // A next_iteration node for control flow.
    434 Node* AddControlNext(NodeBuilder::NodeOut input, const string& device_name,
    435                      const GraphDefBuilder::Options& bopts) {
    436   Node* res_node = ops::UnaryOp("NextIteration", std::move(input), bopts);
    437   if (bopts.HaveError()) return nullptr;
    438   res_node->set_assigned_device_name(device_name);
    439   return res_node;
    440 }
    441 
    442 Node* EmptyConst(const GraphDefBuilder::Options& options) {
    443   if (options.HaveError()) return nullptr;
    444   NodeBuilder node_builder(options.GetNameForOp("Const"), "Const",
    445                            options.op_registry());
    446   const DataType dt = DataTypeToEnum<float>::v();
    447   TensorProto proto;
    448   proto.set_dtype(dt);
    449   TensorShape empty_shape({0});
    450   empty_shape.AsProto(proto.mutable_tensor_shape());
    451   node_builder.Attr("dtype", dt).Attr("value", proto);
    452   return options.FinalizeBuilder(&node_builder);
    453 }
    454 
    455 // A dummy const node for control flow.
    456 Node* AddControlConst(const string& device_name,
    457                       const GraphDefBuilder::Options& bopts) {
    458   Node* res_node = EmptyConst(bopts);
    459   if (bopts.HaveError()) return nullptr;
    460   res_node->set_assigned_device_name(device_name);
    461   return res_node;
    462 }
    463 
    464 // A synthetic loop, made up of dummy nodes. It performs control-flow actions
    465 // on behalf of a leader on a different device.
    466 struct ControlLoop {
    467   Node* enter = nullptr;
    468   Node* merge = nullptr;
    469   Node* switch_node = nullptr;
    470 };
    471 
    472 // Add the control flow info of a new node added during partitioning.
    473 // The new node has the same control flow info as src.
    474 void AddControlFlowInfo(const Node* node, const Node* src,
    475                         std::vector<ControlFlowInfo>* cf_info) {
    476   int id = node->id();
    477   if (static_cast<size_t>(id) >= cf_info->size()) {
    478     cf_info->resize(id + 1);
    479   }
    480   const ControlFlowInfo& src_info = (*cf_info)[src->id()];
    481   ControlFlowInfo* info = &(*cf_info)[id];
    482   info->frame = src_info.frame;
    483   info->parent_frame = src_info.parent_frame;
    484   info->frame_name = src_info.frame_name;
    485 }
    486 
    487 // Constructs a control loop. Returns a struct containing the newly created
    488 // enter, merge, and switch nodes. The enter and merge nodes are used in the
    489 // recursive construction of control loops for nested frames (loops). The
    490 // switch node will be connected to the LoopCond node. The merge node will
    491 // be connected to all the recvs of the same frame by control edges when
    492 // the actual partitioning happens.
    493 Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src,
    494                       const Edge* edge, Node* loop_cond,
    495                       std::vector<ControlFlowInfo>* cf_info,
    496                       ControlLoop* loop) {
    497   Status status;
    498   GraphDefBuilder::Options bopts(g, &status);
    499   const ControlFlowInfo& src_info = (*cf_info)[src->id()];
    500   const string& device_name = edge->dst()->assigned_device_name();
    501   const string& frame_name = src_info.frame_name;
    502   int parallel_iterations;
    503   status = GetNodeAttr(src_info.frame->attrs(), "parallel_iterations",
    504                        &parallel_iterations);
    505   if (!status.ok()) return status;
    506 
    507   // The names of the nodes to be added.
    508   const string& enter_name =
    509       ControlLoopName(opts.new_name(edge->dst()->name()));
    510   const string& merge_name =
    511       ControlLoopName(opts.new_name(edge->dst()->name()));
    512   const string& switch_name =
    513       ControlLoopName(opts.new_name(edge->dst()->name()));
    514   const string& next_name = ControlLoopName(opts.new_name(edge->dst()->name()));
    515 
    516   // Add the nodes to the graph g.
    517   Node* enter = AddControlEnter(g, enter_name, device_name, frame_name,
    518                                 parallel_iterations, &status);
    519   if (!status.ok()) return status;
    520   Node* merge = AddControlMerge(enter_name, next_name, g, merge_name,
    521                                 device_name, &status);
    522   if (!status.ok()) return status;
    523   Node* switch_node = AddControlSwitch(merge, loop_cond, device_name,
    524                                        bopts.WithName(switch_name));
    525   if (!status.ok()) return status;
    526   Node* next =
    527       AddControlNext({switch_node, 1}, device_name, bopts.WithName(next_name));
    528   if (!status.ok()) return status;
    529 
    530   // Add control flow info for these new nodes:
    531   AddControlFlowInfo(enter, src, cf_info);
    532   AddControlFlowInfo(merge, src, cf_info);
    533   AddControlFlowInfo(switch_node, src, cf_info);
    534   AddControlFlowInfo(next, src, cf_info);
    535 
    536   // Add input edges for the newly created merge node:
    537   g->AddEdge(enter, 0, merge, 0);
    538   g->AddEdge(next, 0, merge, 1);
    539 
    540   loop->enter = enter;
    541   loop->merge = merge;
    542   loop->switch_node = switch_node;
    543   return Status::OK();
    544 }
    545 
    546 // Build memory and device type info for every node in the graph.
    547 // TODO(yuanbyu): It might be simpler if we convert MemoryType to
    548 // DeviceType for the inputs/outputs of each node.
    549 Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) {
    550   MemoryTypeVector input_memory_types;
    551   MemoryTypeVector output_memory_types;
    552 
    553   info->device_types.resize(g.num_node_ids(), DEVICE_CPU);
    554   for (const Node* node : g.op_nodes()) {
    555     DeviceNameUtils::ParsedName parsed;
    556     if (!DeviceNameUtils::ParseFullName(node->assigned_device_name(),
    557                                         &parsed)) {
    558       return errors::Internal("Malformed assigned device '",
    559                               node->assigned_device_name(), "'");
    560     }
    561 
    562     TF_RETURN_IF_ERROR(MemoryTypesForNode(
    563         g.op_registry(), DeviceType(parsed.type), node->def(),
    564         &input_memory_types, &output_memory_types));
    565 
    566     int node_id = node->id();
    567     info->device_types[node_id] = DeviceType(parsed.type);
    568     for (int i = 0; i < input_memory_types.size(); ++i) {
    569       info->input_types[{node_id, i}] = input_memory_types[i];
    570     }
    571     for (int i = 0; i < output_memory_types.size(); ++i) {
    572       info->output_types[{node_id, i}] = output_memory_types[i];
    573     }
    574   }
    575   return Status::OK();
    576 }
    577 
    578 const Node* InputFrame(const Node* node,
    579                        const std::vector<ControlFlowInfo>& cf_info) {
    580   // An input is in the same frame as the node except for Enter nodes.
    581   // The input of Enter is in the parent frame of the Enter node.
    582   if (!node->IsEnter()) {
    583     return node;
    584   }
    585   return cf_info[node->id()].parent_frame;
    586 }
    587 
    588 const Node* OutputFrame(const Node* node,
    589                         const std::vector<ControlFlowInfo>& cf_info) {
    590   // An output is in the same frame as the node except for Exit nodes.
    591   // The output of Exit is in the parent frame of the Exit node.
    592   if (!node->IsExit()) {
    593     return node;
    594   }
    595   return cf_info[node->id()].parent_frame;
    596 }
    597 
    598 // Each participating device needs to decide a) if there is a next iteration,
    599 // and b) if the loop terminates. We take the approach to encode this control
    600 // flow logic in the dataflow graph. There are at least two possible encodings.
    601 // In a completely decentralized encoding, the participants communicate peer
    602 // to peer. The other encoding uses a frame leader (the participant who owns
    603 // the pivot termination predicate) to broadcast the termination condition to
    604 // all the participants. For now we take the latter because it is simpler.
    605 //
    606 // TODO(yuanbyu): The correctness of this construction is rather subtle. I got
    607 // it wrong many times so it would be nice to write a proof to be sure.
    608 Status AddControlFlow(const PartitionOptions& opts, Graph* g,
    609                       GraphInfo* g_info) {
    610   Status status;
    611   GraphDefBuilder::Options bopts(g, &status);
    612   std::vector<ControlFlowInfo>& cf_info = g_info->cf_info;
    613 
    614   // Build the control flow info for every node.
    615   status = BuildControlFlowInfo(g, &cf_info);
    616   if (!status.ok()) return status;
    617 
    618   OptimizeControlFlowColocation(g);
    619 
    620   // The map from frames to their LoopCond nodes.
    621   std::unordered_map<string, Node*> frame_cond_map;
    622   int num_node_ids = g->num_node_ids();
    623   for (int i = 0; i < num_node_ids; ++i) {
    624     Node* node = g->FindNodeId(i);
    625     if (node == nullptr) continue;
    626 
    627     if (IsLoopCond(node)) {
    628       const string& frame_name = cf_info[node->id()].frame_name;
    629       DCHECK(!frame_name.empty());
    630       frame_cond_map[frame_name] = node;
    631     }
    632   }
    633 
    634   // Add all control loops for cross-device frames.
    635   // A control loop is added only when there is a cross-device edge in a
    636   // non-root frame. Nothing is added if there is no loops. We also don't
    637   // add anything for a frame that is completely local to a device. For
    638   // nested loops, we stack the control loops together by connecting
    639   // the merge of the outer loop to the enter of the inner loop.
    640   //
    641   // A map from <frame_name, device_name> to ControlLoop.
    642   std::unordered_map<string, ControlLoop> control_loops;
    643   int num_edge_ids = g->num_edge_ids();
    644   for (int i = 0; i < num_edge_ids; ++i) {
    645     const Edge* edge = g->FindEdgeId(i);
    646     if (edge == nullptr) continue;
    647 
    648     const Node* src = edge->src();
    649     const Node* dst = edge->dst();
    650     // Skip Sink/Source nodes.
    651     if (!src->IsOp() || !dst->IsOp()) continue;
    652 
    653     const string& src_device = src->assigned_device_name();
    654     const string& dst_device = dst->assigned_device_name();
    655     // Skip local edges.
    656     if (src_device == dst_device) continue;
    657 
    658     const Node* src_frame = OutputFrame(src, cf_info);
    659     const Node* dst_frame = InputFrame(dst, cf_info);
    660     const string& src_frame_name = cf_info[src_frame->id()].frame_name;
    661     const string& dst_frame_name = cf_info[dst_frame->id()].frame_name;
    662     // Skip if src and dst are not in the same frame.
    663     if (src_frame_name.empty() || src_frame_name != dst_frame_name) {
    664       continue;
    665     }
    666 
    667     // Add the control loop. Start by adding the control loop for the
    668     // current frame if needed, and recursively adding the control loop
    669     // for its outer frame when nested.
    670     ControlLoop child_loop;
    671     while (true) {
    672       const string& curr_frame_name = cf_info[src_frame->id()].frame_name;
    673       if (curr_frame_name.empty()) {
    674         // We have reached the root frame.
    675         if (child_loop.merge != nullptr) {
    676           const string& node_name = opts.new_name(edge->dst()->name());
    677           const string& device_name = edge->dst()->assigned_device_name();
    678           Node* const_node =
    679               AddControlConst(device_name, bopts.WithName(node_name));
    680           if (!status.ok()) return status;
    681           AddControlFlowInfo(const_node, src_frame, &cf_info);
    682           g->AddEdge(const_node, 0, child_loop.enter, 0);
    683         }
    684         break;
    685       }
    686 
    687       const string& cl_key = strings::StrCat(curr_frame_name, "$$", dst_device);
    688       auto it = control_loops.find(cl_key);
    689       if (it != control_loops.end()) {
    690         if (child_loop.enter != nullptr) {
    691           g->AddEdge(it->second.merge, 0, child_loop.enter, 0);
    692         }
    693         break;
    694       }
    695 
    696       // Get the frame's LoopCond.
    697       auto cond_it = frame_cond_map.find(curr_frame_name);
    698       if (cond_it == frame_cond_map.end()) {
    699         return errors::InvalidArgument(
    700             "A cross-device loop must have a pivot predicate: ",
    701             curr_frame_name);
    702       }
    703       Node* loop_cond = cond_it->second;
    704 
    705       // Add the control loop.
    706       ControlLoop curr_loop;
    707       status = AddControlLoop(opts, g, src_frame, edge, loop_cond, &cf_info,
    708                               &curr_loop);
    709       if (!status.ok()) return status;
    710       control_loops[cl_key] = curr_loop;
    711 
    712       if (child_loop.enter != nullptr) {
    713         // Connect the merge of the outer loop to the enter of the inner.
    714         g->AddEdge(curr_loop.merge, 0, child_loop.enter, 0);
    715       }
    716       src_frame = cf_info[src_frame->id()].parent_frame;
    717       child_loop = curr_loop;
    718     }
    719   }
    720 
    721   // For a cross-device edge, on the dst device, add a control edge
    722   // from the merge node of the control loop to dst. If a send/recv is
    723   // introduced for this edge in future partitioning, we delete this
    724   // control edge and add a new control edge from the merge to the recv.
    725   num_edge_ids = g->num_edge_ids();
    726   for (int i = 0; i < num_edge_ids; ++i) {
    727     const Edge* edge = g->FindEdgeId(i);
    728     if (edge == nullptr) continue;
    729 
    730     const Node* src = edge->src();
    731     Node* dst = edge->dst();
    732     // Skip Sink/Source nodes.
    733     if (!src->IsOp() || !dst->IsOp()) continue;
    734 
    735     const string& src_device = src->assigned_device_name();
    736     const string& dst_device = dst->assigned_device_name();
    737     if (src_device != dst_device) {
    738       const Node* src_frame = OutputFrame(src, cf_info);
    739       const Node* dst_frame = InputFrame(dst, cf_info);
    740       const string& src_frame_name = cf_info[src_frame->id()].frame_name;
    741       const string& dst_frame_name = cf_info[dst_frame->id()].frame_name;
    742       if (!src_frame_name.empty() && src_frame_name == dst_frame_name) {
    743         const string& cl_key =
    744             strings::StrCat(dst_frame_name, "$$", dst_device);
    745         ControlLoop loop = control_loops[cl_key];
    746         DCHECK(loop.enter != nullptr);
    747         // Note that we'll create multiple duplicate edges if dst has multiple
    748         // cross-device inputs. This is expected by the logic in Partition(), so
    749         // it can add control edges to the recv nodes once they're created.
    750         g->AddControlEdge(loop.merge, dst, /*allow_duplicates=*/true);
    751       }
    752     }
    753   }
    754   return Status::OK();
    755 }
    756 
    757 struct PriorityTopoSortNode {
    758   PriorityTopoSortNode(const NodeDef* n, int64 st) : node(n), start_time(st) {}
    759 
    760   const NodeDef* node;
    761   int64 start_time;
    762 };
    763 
    764 struct PriorityTopoSortNodeGreater {
    765   bool operator()(const PriorityTopoSortNode& left,
    766                   const PriorityTopoSortNode& right) {
    767     return left.start_time > right.start_time;
    768   }
    769 };
    770 
    771 }  // namespace
    772 
    773 // Returns in <nodes> the nodes that should participate in epoch-based recv
    774 // scheduling, along with their times; <nodes> is ordered by increasing
    775 // start_time. Returns in <node_to_start_time_out> the timing for all nodes,
    776 // even those not in <nodes>.
    777 //
    778 // Comparing to sorting on the node's start time only, this also processes the
    779 // nodes in dependency order, and updates start times to ensure a node's
    780 // start_time > the start time for all dependencies.
    781 //
    782 // Note that graph_partition_test.cc accesses this function for testing, even
    783 // though it's not declared in the header.
    784 Status TopologicalSortNodesWithTimePriority(
    785     const GraphDef* gdef, std::vector<std::pair<const NodeDef*, int64>>* nodes,
    786     std::unordered_map<const NodeDef*, int64>* node_to_start_time_out) {
    787   // Queue of nodes to process; lowest start time is returned first.
    788   std::priority_queue<PriorityTopoSortNode, std::vector<PriorityTopoSortNode>,
    789                       PriorityTopoSortNodeGreater>
    790       q;
    791   std::unordered_map<const NodeDef*, int64> node_to_start_time;
    792   auto enqueue = [&q, &node_to_start_time](const NodeDef* node) {
    793     const int64 start_time = node_to_start_time[node];
    794     q.emplace(node, start_time);
    795   };
    796 
    797   // Build initial structures, initial contents of queue.
    798   std::unordered_map<string, std::vector<const NodeDef*>> node_to_output_nodes;
    799   std::unordered_map<const NodeDef*, int> inputs_needed;
    800   for (int n = 0; n < gdef->node_size(); ++n) {
    801     const NodeDef* ndef = &gdef->node(n);
    802     for (int i = 0; i < ndef->input_size(); ++i) {
    803       node_to_output_nodes[string(ParseTensorName(ndef->input(i)).first)]
    804           .push_back(ndef);
    805     }
    806     int64 start_time;
    807     TF_RETURN_IF_ERROR(GetNodeAttr(*ndef, "_start_time", &start_time));
    808     node_to_start_time[ndef] = start_time;
    809     inputs_needed[ndef] = ndef->input_size();
    810     if (ndef->input_size() == 0) {
    811       enqueue(ndef);
    812     }
    813   }
    814 
    815   // Determine which merge nodes are parts of loops; these
    816   // need to happen in the traversal after all non-NextIteration inputs
    817   // are run.
    818   for (int n = 0; n < gdef->node_size(); ++n) {
    819     const NodeDef* ndef = &gdef->node(n);
    820     if (IsNextIteration(*ndef)) {
    821       for (const NodeDef* n : node_to_output_nodes[ndef->name()]) {
    822         if (IsMerge(*n)) {
    823           // n is a merge that is part of a loop structure.
    824           // It doesn't need to wait for this NextIteration loop
    825           // when doing the traversal.
    826           --inputs_needed[n];
    827         }
    828       }
    829     }
    830   }
    831 
    832   // Traverse.
    833   std::vector<std::pair<const NodeDef*, int64>> start_times;
    834   start_times.reserve(gdef->node_size());
    835   while (!q.empty()) {
    836     PriorityTopoSortNode cur = q.top();
    837     q.pop();
    838 
    839     start_times.emplace_back(cur.node, cur.start_time);
    840 
    841     for (const NodeDef* n : node_to_output_nodes[cur.node->name()]) {
    842       auto& output_start_time = node_to_start_time[n];
    843       if (output_start_time <= cur.start_time) {
    844         output_start_time = cur.start_time + 1;
    845       }
    846       if (--inputs_needed[n] == 0) {
    847         enqueue(n);
    848       }
    849     }
    850   }
    851 
    852   // Done.
    853   nodes->swap(start_times);
    854   node_to_start_time_out->swap(node_to_start_time);
    855   return Status::OK();
    856 }
    857 
    858 Status AddControlEdges(const PartitionOptions& opts,
    859                        std::unordered_map<string, GraphDef>* partitions) {
    860   Status status;
    861   // TODO(yuanbyu): Very naive for now. To be improved.
    862   const int num_epochs = 100;
    863   const int prefetch = 6;
    864 
    865   for (auto& part : *partitions) {
    866     GraphDef* gdef = &part.second;
    867     std::vector<std::pair<const NodeDef*, int64>> start_times;
    868     std::unordered_map<const NodeDef*, int64> node_to_start_time;
    869     status = TopologicalSortNodesWithTimePriority(gdef, &start_times,
    870                                                   &node_to_start_time);
    871     if (!status.ok()) {
    872       return status;
    873     }
    874 
    875     // Add a dummy node for every epoch, and add a control edge from the
    876     // "last" node in the preceding epoch to the dummy node.
    877     string device_name = gdef->node(0).device();
    878     int64 makespan = start_times.back().second;
    879     int64 resolution = (makespan / num_epochs) + 1;
    880 
    881     int i = 0;
    882     int j = 0;
    883     std::vector<NodeDef*> dummys;
    884     while (i < num_epochs && static_cast<size_t>(j) < start_times.size()) {
    885       if (i * resolution > start_times[j].second) {
    886         j++;
    887       } else {
    888         NodeDef* dummy = AddControlTrigger(opts, gdef, device_name, i,
    889                                            i * resolution, &status);
    890         if (!status.ok()) {
    891           return status;
    892         }
    893         dummys.push_back(dummy);
    894         if (j > 0) {
    895           string src_name = start_times[j - 1].first->name();
    896           AddInput(dummy, src_name, Graph::kControlSlot);
    897         }
    898         i++;
    899       }
    900     }
    901 
    902     // Finally, add the control edges to recvs.
    903     for (int n = 0; n < gdef->node_size(); ++n) {
    904       NodeDef* ndef = gdef->mutable_node(n);
    905       if (ndef->op() == "_Recv") {
    906         const int64 start_time = node_to_start_time[ndef];
    907         const int recv_epoch = start_time / resolution;
    908         if (recv_epoch >= prefetch) {
    909           NodeDef* dummy = dummys[recv_epoch - prefetch];
    910           AddInput(ndef, dummy->name(), Graph::kControlSlot);
    911         }
    912       }
    913     }
    914   }
    915   return Status::OK();
    916 }
    917 
    918 // If 'ndef' is a Send or Recv, fills its attr send_device_incarnation
    919 // if possible.
    920 void SetIncarnation(const PartitionOptions& opts, NodeDef* ndef) {
    921   StringPiece op(ndef->op());
    922   if (op != "_Send" && op != "_Recv") {
    923     // Not related to send/recv.
    924     return;
    925   }
    926   string send_device;
    927   if (!GetNodeAttr(*ndef, "send_device", &send_device).ok()) {
    928     // No known send_device. The runtime will detect it later.
    929     return;
    930   }
    931   int64 incarnation = PartitionOptions::kIllegalIncarnation;
    932   if (!GetNodeAttr(*ndef, "send_device_incarnation", &incarnation).ok() ||
    933       (incarnation == PartitionOptions::kIllegalIncarnation)) {
    934     incarnation = opts.get_incarnation(send_device);
    935     SetAttrValue(incarnation,
    936                  &((*ndef->mutable_attr())["send_device_incarnation"]));
    937   }
    938 }
    939 
    940 // Sets attribute send_device_incarnation of all Send/Recv nodes in
    941 // 'gdef', if possible.
    942 void SetIncarnation(const PartitionOptions& opts, GraphDef* gdef) {
    943   for (NodeDef& ndef : *gdef->mutable_node()) {
    944     SetIncarnation(opts, &ndef);
    945   }
    946   for (FunctionDef& fdef : *gdef->mutable_library()->mutable_function()) {
    947     for (NodeDef& ndef : *fdef.mutable_node_def()) {
    948       SetIncarnation(opts, &ndef);
    949     }
    950   }
    951 }
    952 
    953 Status Partition(const PartitionOptions& opts, Graph* g,
    954                  std::unordered_map<string, GraphDef>* partitions) {
          // Splits the op nodes of `g` into one GraphDef per location and stores
          // them in `*partitions` (which is cleared first), keyed by
          // opts.node_to_loc(node).
          //
          // An edge whose endpoints land in the same partition is kept as a plain
          // input unless NeedSameDeviceSendRecv() says otherwise. Every other edge
          // is replaced by a send node in the source partition and a recv node in
          // the destination partition; a control edge is carried by a dummy const
          // feeding the send. Non-ref transfers of the same (source node, output,
          // destination partition, host flag) are shared through `dup_recv`.
          //
          // When opts.control_flow_added is false, AddControlFlow() is run first,
          // which mutates `g`.
          //
          // Returns InvalidArgument if a node is missing data inputs, or the first
          // error reported by a helper.
    955   Status status;
    956   partitions->clear();
    957 
    958   GraphInfo g_info;
    959   if (!opts.control_flow_added) {
    960     // Add the "code" for distributed execution of control flow. Code is
    961     // added only for the frames that are placed on multiple devices. The
    962     // new graph is an equivalent transformation of the original graph and
    963     // has the property that it can be subsequently partitioned arbitrarily
    964     // (down to the level of individual device) for distributed execution.
    965     status = AddControlFlow(opts, g, &g_info);
    966     if (!status.ok()) return status;
    967   }
    968 
    969   // At this point, all the graph mutations have been done. Build memory
    970   // and device type info for every node and edge in the graph.
    971   status = BuildMemoryDeviceInfo(*g, &g_info);
    972   if (!status.ok()) return status;
    973 
          // Scratch state reused across destination nodes; `dstp` and `inputs` are
          // reset for every node in the loop below.
    974   string dstp;
    975   std::vector<const Edge*> inputs;
          // Send/recv pairs created so far, so later edges can reuse them.
          // NOTE(review): the argument 3 is presumably an initial bucket count --
          // confirm against the DupRecvTable definition.
    976   DupRecvTable dup_recv(3);
    977   // For a node dst, 'ref_recvs' remembers the recvs introduced by a ref
    978   // edge to dst. 'ref_control_inputs' remembers the inputs by a non-ref
    979   // edge to dst. We will add a control edge for every pair in
    980   // (ref_recvs x ref_control_inputs).
    981   std::vector<NodeDef*> ref_recvs;
    982   std::vector<string> ref_control_inputs;
    983 
          // Counters of newly created send/recv pairs; only used for the VLOG at
          // the end.
    984   int32 num_data = 0;
    985   int32 num_control = 0;
    986   for (const Node* dst : g->op_nodes()) {
            // Copy dst into its partition; its inputs are rebuilt from the edges.
    987     dstp = opts.node_to_loc(dst);
    988     GraphDef* dst_graph = &(*partitions)[dstp];
    989     NodeDef* dst_def = dst_graph->add_node();
    990     *dst_def = dst->def();
    991     MergeDebugInfo(NodeDebugInfo(dst->def()), dst_def);
    992     dst_def->set_device(dst->assigned_device_name());
    993     dst_def->clear_input();  // Inputs are filled below
    994     if (opts.need_to_record_start_times) {
              // Keep an existing "_start_time" attr; otherwise take the value from
              // opts.start_times.
    995       int64 start_time;
    996       status = GetNodeAttr(*dst_def, "_start_time", &start_time);
    997       if (errors::IsNotFound(status)) {
    998         start_time = opts.start_times[dst->id()].value();
    999         AddNodeAttr("_start_time", start_time, dst_def);
   1000       } else if (!status.ok()) {
   1001         return status;
   1002       }
   1003     }
   1004 
   1005     // Arrange the incoming edges to dst so that input[i] holds the
   1006     // input flowing into slot numbered i. Trailing entries in input[]
   1007     // hold control edges.
   1008     inputs.clear();
   1009     inputs.resize(dst->num_inputs(), nullptr);
   1010     ref_recvs.clear();
   1011     ref_control_inputs.clear();
   1012     const Edge* control_flow_edge = nullptr;
   1013     int32 num_control_flow_edges = 0;
   1014     int32 num_input_edges = 0;
   1015     for (const Edge* edge : dst->in_edges()) {
   1016       if (edge->IsControlEdge()) {
   1017         if (IsMerge(edge->src()) && IsControlLoop(edge->src())) {
   1018           // This is one of the control edges added for control flow. There
   1019           // can be multiple such edges as the dest node may have multiple
   1020           // remote inputs. We keep track of the number of such edges.
   1021           control_flow_edge = edge;
   1022           ++num_control_flow_edges;
   1023         } else {
   1024           inputs.push_back(edge);
   1025         }
   1026       } else {
   1027         DCHECK(inputs[edge->dst_input()] == nullptr);
   1028         inputs[edge->dst_input()] = edge;
   1029         ++num_input_edges;
   1030       }
   1031     }
   1032 
            // Every data slot must have been filled; otherwise `inputs` would
            // still contain null entries that the loop below dereferences.
   1033     if (num_input_edges != dst->num_inputs()) {
   1034       return errors::InvalidArgument("Incomplete graph, missing ",
   1035                                      (dst->num_inputs() - num_input_edges),
   1036                                      " inputs for ", dst->name());
   1037     }
   1038 
   1039     // Process in order so that all data edges are added as inputs to
   1040     // dst in Edge::dst_input() order.
   1041     for (const Edge* edge : inputs) {
   1042       const Node* src = edge->src();
   1043       if (!src->IsOp()) continue;  // Skip Sink/Source nodes.
   1044 
   1045       GraphDef* src_graph = &(*partitions)[opts.node_to_loc(src)];
   1046       if (src_graph == dst_graph && !NeedSameDeviceSendRecv(edge, g_info)) {
   1047         // Same partition and compatible memory types:
   1048         AddInput(dst_def, src->name(), edge->src_output());
   1049         if (edge->IsControlEdge() ||
   1050             !IsRefType(src->output_type(edge->src_output()))) {
   1051           ref_control_inputs.push_back(src->name());
   1052         }
   1053         continue;
   1054       }
   1055 
              // Start times stay 0 unless recv scheduling is enabled. A missing
              // "_start_time" attr is tolerated only when
              // opts.need_to_record_start_times is set, in which case
              // opts.start_times supplies the value.
   1056       int64 send_start_time = 0;
   1057       int64 recv_start_time = 0;
   1058       if (opts.scheduling_for_recvs) {
   1059         status = GetNodeAttr(src->attrs(), "_start_time", &send_start_time);
   1060         if (errors::IsNotFound(status) && opts.need_to_record_start_times) {
   1061           send_start_time = opts.start_times[src->id()].value();
   1062         } else if (!status.ok()) {
   1063           return status;
   1064         }
   1065 
   1066         status = GetNodeAttr(dst->attrs(), "_start_time", &recv_start_time);
   1067         if (errors::IsNotFound(status) && opts.need_to_record_start_times) {
   1068           recv_start_time = opts.start_times[dst->id()].value();
   1069         } else if (!status.ok()) {
   1070           return status;
   1071         }
   1072       }
   1073 
   1074       // Check whether there is already a send/recv pair transferring
   1075       // the same tensor/control from the src to dst partition.
   1076       const bool on_host = IsDstInputOnHost(edge, g_info);
   1077       DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host};
   1078       auto iter = dup_recv.find(key);
   1079       if (iter != dup_recv.end()) {
   1080         // We found one. Reuse the data/control transferred already.
   1081         const string& recv_node_name = iter->second.recv->name();
   1082         if (edge->IsControlEdge()) {
   1083           AddInput(dst_def, recv_node_name, Graph::kControlSlot);
   1084         } else {
   1085           AddInput(dst_def, recv_node_name, 0);
   1086         }
   1087         ref_control_inputs.push_back(recv_node_name);
   1088 
   1089         // We want the start_time for the recv to be the smallest of the start
   1090         // times of its consumers. So we update this whenever we use a recv,
   1091         // and write it out to the attribute at the end of the subroutine
   1092         if (iter->second.start_time > recv_start_time) {
   1093           iter->second.start_time = recv_start_time;
   1094         }
   1095         continue;
   1096       }
   1097 
              // No reusable pair: decide what the new send node reads from.
   1098       NodeDefBuilder::NodeOut send_from;
   1099       if (edge->IsControlEdge()) {
   1100         // Insert a dummy const node that will generate a tiny
   1101         // data element to be sent from send to recv.
   1102         VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "["
   1103                 << src->name() << "] -> " << dst->assigned_device_name() << "["
   1104                 << dst->name() << "]";
   1105         NodeDef* dummy = AddDummyConst(opts, src_graph, edge, &status);
   1106         if (!status.ok()) return status;
   1107         // Set the start time for this dummy node.
   1108         if (opts.scheduling_for_recvs) {
   1109           AddNodeAttr("_start_time", send_start_time, dummy);
   1110         }
   1111         AddInput(dummy, src->name(), Graph::kControlSlot);
   1112         send_from.Reset(dummy->name(), 0, DT_FLOAT);
   1113       } else {
   1114         send_from.Reset(src->name(), edge->src_output(), EdgeType(edge));
   1115       }
   1116 
   1117       // Need to split edge by placing matching send/recv nodes on
   1118       // the src/dst sides of the edge.
   1119       NodeDef* send = AddSend(opts, g_info, src_graph, edge, send_from,
   1120                               send_start_time, &status);
   1121       if (!status.ok()) return status;
   1122 
              // AddRecv returns the node dst should read from and reports the
              // actual recv node through `real_recv`; the two may differ (see the
              // `real_recv != recv` checks below).
   1123       NodeDef* real_recv = nullptr;
   1124       NodeDef* recv =
   1125           AddRecv(opts, g_info, dst_graph, edge, &real_recv, &status);
   1126       if (!status.ok()) return status;
   1127 
   1128       // Fix up the control flow edge.
   1129       // NOTE(yuanbyu): 'real_recv' must be the real recv node.
   1130       if (src_graph == dst_graph) {
   1131         // For same device send/recv, add a control edge from send to recv.
   1132         // This prevents the asynchronous recv kernel from being scheduled
   1133         // before the data is available.
   1134         AddInput(real_recv, send->name(), Graph::kControlSlot);
   1135       } else if (control_flow_edge != nullptr) {
   1136         // Redirect control edge to the real recv since this is not the same
   1137         // device send/recv.
   1138         --num_control_flow_edges;
   1139         AddInput(real_recv, control_flow_edge->src()->name(),
   1140                  Graph::kControlSlot);
   1141       }
   1142 
   1143       if (!edge->IsControlEdge() &&
   1144           IsRefType(src->output_type(edge->src_output()))) {
                // Ref edges are never entered into dup_recv, so their start time
                // is written here rather than in the final pass over dup_recv.
   1145         AddNodeAttr("_start_time", recv_start_time, recv);
   1146         if (real_recv != recv) {
   1147           AddNodeAttr("_start_time", recv_start_time, real_recv);
   1148         }
   1149         // If src is of ref type and the edge is not a control edge, dst has
   1150         // read semantics and therefore we must control the recv.
   1151         ref_recvs.push_back(real_recv);
   1152       } else {
   1153         // Memorize the send/recv pair, only if this is not a "ref" edge.
   1154         // NOTE(yuanbyu): Collapsing ref edges requires extreme care so
   1155         // for now we don't do it.
   1156         dup_recv[key] = {recv, real_recv, recv_start_time};
   1157         ref_control_inputs.push_back(recv->name());
   1158       }
   1159 
              // Finally hook the recv up to dst and count the new pair.
   1160       if (edge->IsControlEdge()) {
   1161         ++num_control;
   1162         AddInput(dst_def, recv->name(), Graph::kControlSlot);
   1163       } else {
   1164         ++num_data;
   1165         AddInput(dst_def, recv->name(), 0);
   1166       }
   1167     }
   1168 
   1169     // Add control edges from 'ref_control_inputs' to 'ref_recvs'.
   1170     // NOTE(yuanbyu): Adding these control edges should not introduce
   1171     // deadlocks. 'dst' has implicit "read" nodes that, when we split
   1172     // across devices, are made explicit; Retargeting the dependencies
   1173     // to 'dst' to those nodes would not introduce cycles if there isn't
   1174     // one before the transformation.
   1175     // NOTE(yuanbyu): This may impact performance because it defers the
   1176     // execution of recvs until all the other inputs become available.
   1177     AddReadControl(ref_recvs, ref_control_inputs);
   1178 
   1179     // Add back the control edges for control flow that are not used.
   1180     if (control_flow_edge != nullptr) {
   1181       for (int i = 0; i < num_control_flow_edges; ++i) {
   1182         AddInput(dst_def, control_flow_edge->src()->name(),
   1183                  Graph::kControlSlot);
   1184       }
   1185     }
   1186   }
   1187 
          // Use the caller-supplied function library if there is one; otherwise
          // fall back to the library of the graph being partitioned.
   1188   const FunctionLibraryDefinition* flib_def = opts.flib_def;
   1189   if (flib_def == nullptr) {
   1190     flib_def = &g->flib_def();
   1191   }
   1192 
   1193   // Set versions, function library and send/recv incarnation.
   1194   for (auto& it : *partitions) {
   1195     GraphDef* gdef = &it.second;
   1196     *gdef->mutable_versions() = g->versions();
   1197     // Prune unreachable functions from `flib_def` before adding them to `gdef`.
   1198     *gdef->mutable_library() = flib_def->ReachableDefinitions(*gdef).ToProto();
   1199 
   1200     // Traverse the graph to fill every send/recv op's incarnation
   1201     // information.
   1202     SetIncarnation(opts, gdef);
   1203   }
   1204 
   1205   // Set the start times for recvs at the very end.
          // By now each dup_recv entry holds the smallest start time among the
          // consumers that used that recv.
   1206   if (opts.scheduling_for_recvs) {
   1207     for (auto& it : dup_recv) {
   1208       AddNodeAttr("_start_time", it.second.start_time, it.second.recv);
   1209       if (it.second.real_recv != it.second.recv) {
   1210         AddNodeAttr("_start_time", it.second.start_time, it.second.real_recv);
   1211       }
   1212     }
   1213   }
   1214 
   1215   VLOG(1) << "Added send/recv: controls=" << num_control
   1216           << ", data=" << num_data;
   1217   return Status::OK();
   1218 }
   1219 
   1220 }  // namespace tensorflow
   1221