Home | History | Annotate | Download | only in graph
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/graph/graph_partition.h"
     17 
     18 #include <deque>
     19 #include <queue>
     20 #include <unordered_map>
     21 #include <unordered_set>
     22 #include <utility>
     23 #include <vector>
     24 
     25 #include "tensorflow/core/framework/memory_types.h"
     26 #include "tensorflow/core/framework/node_def_builder.h"
     27 #include "tensorflow/core/framework/tensor.pb.h"
     28 #include "tensorflow/core/framework/types.h"
     29 #include "tensorflow/core/framework/versions.pb.h"
     30 #include "tensorflow/core/graph/algorithm.h"
     31 #include "tensorflow/core/graph/control_flow.h"
     32 #include "tensorflow/core/graph/costmodel.h"
     33 #include "tensorflow/core/graph/graph_def_builder.h"
     34 #include "tensorflow/core/graph/node_builder.h"
     35 #include "tensorflow/core/graph/tensor_id.h"
     36 #include "tensorflow/core/lib/core/errors.h"
     37 #include "tensorflow/core/lib/hash/hash.h"
     38 #include "tensorflow/core/platform/logging.h"
     39 #include "tensorflow/core/util/device_name_utils.h"
     40 
     41 namespace tensorflow {
     42 
     43 namespace {
     44 
     45 inline bool IsMerge(const NodeDef& node_def) {
     46   return node_def.op() == "Merge" || node_def.op() == "RefMerge";
     47 }
     48 
     49 inline bool IsNextIteration(const NodeDef& node_def) {
     50   return node_def.op() == "NextIteration" ||
     51          node_def.op() == "RefNextIteration";
     52 }
     53 
     54 struct DupRecvKey {
     55   int src_node_id;           // Edge's src node id
     56   int src_output_slot;       // Edge's src node output slot
     57   GraphDef* dst_graph;       // Edge's dst node is in this subgraph
     58   bool recv_output_on_host;  // The output of recv is on host
     59 };
     60 
     61 struct DupRecvKeyHash {
     62   size_t operator()(const DupRecvKey& k) const {
     63     size_t h = Hash64(reinterpret_cast<const char*>(&k.src_node_id),
     64                       sizeof(k.src_node_id), k.src_output_slot);
     65     h = Hash64(reinterpret_cast<const char*>(&k.dst_graph), sizeof(k.dst_graph),
     66                h);
     67     h = Hash64(reinterpret_cast<const char*>(&k.recv_output_on_host),
     68                sizeof(k.recv_output_on_host), h);
     69     return h;
     70   }
     71 };
     72 
     73 struct DupRecvKeyEq {
     74   bool operator()(const DupRecvKey& x, const DupRecvKey& y) const {
     75     return (x.src_node_id == y.src_node_id) &&
     76            (x.src_output_slot == y.src_output_slot) &&
     77            (x.dst_graph == y.dst_graph) &&
     78            (x.recv_output_on_host == y.recv_output_on_host);
     79   }
     80 };
     81 
     82 // struct used to store the recvs, so that start times can be properly updated
     83 struct RecvInfo {
     84   NodeDef* recv;
     85   NodeDef* real_recv;
     86   int64 start_time;
     87 };
     88 
     89 typedef std::unordered_map<DupRecvKey, RecvInfo, DupRecvKeyHash, DupRecvKeyEq>
     90     DupRecvTable;
     91 
     92 struct PairIntHash {
     93  public:
     94   std::size_t operator()(const std::pair<int, int>& x) const {
     95     return std::hash<int>()(x.first) ^ std::hash<int>()(x.second);
     96   }
     97 };
     98 // A map used to store memory types for the inputs/outputs of every node.
     99 // The key is a pair of ints consisting of a node id and input/output index.
    100 typedef std::unordered_map<std::pair<int, int>, MemoryType, PairIntHash>
    101     MemoryTypeMap;
    102 
    103 // We collect the following information about the graph before performing
    104 // graph partitioning.
    105 struct GraphInfo {
    106   std::vector<DeviceType> device_types;
    107   MemoryTypeMap input_types;
    108   MemoryTypeMap output_types;
    109   std::vector<ControlFlowInfo> cf_info;
    110 };
    111 
    112 DataType EdgeType(const Edge* e) {
    113   if (e->IsControlEdge()) {
    114     return DT_FLOAT;
    115   } else {
    116     return e->dst()->input_type(e->dst_input());
    117   }
    118 }
    119 
    120 // Return true iff we need to add the same device send/recv for 'edge'.
    121 bool NeedSameDeviceSendRecv(const Edge* edge, const GraphInfo& info) {
    122   if (edge->IsControlEdge()) {
    123     return false;
    124   }
    125 
    126   Node* src = edge->src();
    127   Node* dst = edge->dst();
    128   if (src->assigned_device_name() == dst->assigned_device_name()) {
    129     int src_port = edge->src_output();
    130     int dst_port = edge->dst_input();
    131     if (info.device_types[src->id()] != DEVICE_CPU) {
    132       auto src_it = info.output_types.find({src->id(), src_port});
    133       DCHECK(src_it != info.output_types.end());
    134       auto dst_it = info.input_types.find({dst->id(), dst_port});
    135       DCHECK(dst_it != info.input_types.end());
    136       return src_it->second != dst_it->second;
    137     }
    138   }
    139   return false;
    140 }
    141 
    142 // Return true iff (dst, dst_input) is specified on host memory.
    143 bool IsDstInputOnHost(const Edge* edge, const GraphInfo& info) {
    144   Node* dst = edge->dst();
    145   int dst_port = edge->dst_input();
    146   if (info.device_types[dst->id()] != DEVICE_CPU) {
    147     if (edge->IsControlEdge()) return false;
    148     auto dst_it = info.input_types.find({dst->id(), dst_port});
    149     DCHECK(dst_it != info.input_types.end());
    150     return dst_it->second == HOST_MEMORY;
    151   }
    152   return true;
    153 }
    154 
    155 // Add an input to dst that comes from the "src_slot" output of the
    156 // node named by "src_name".
    157 void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
    158   if (src_slot == Graph::kControlSlot) {
    159     dst->add_input(strings::StrCat("^", src_name));
    160   } else if (src_slot == 0) {
    161     dst->add_input(src_name.data(), src_name.size());
    162   } else {
    163     dst->add_input(strings::StrCat(src_name, ":", src_slot));
    164   }
    165 }
    166 
    167 // Add a control edge from each input to each recv.
    168 void AddReadControl(const std::vector<NodeDef*>& recvs,
    169                     const std::vector<string>& inputs) {
    170   for (NodeDef* recv : recvs) {
    171     for (const string& input : inputs) {
    172       recv->add_input(strings::StrCat("^", input));
    173     }
    174   }
    175 }
    176 
    177 void SetSendRecvAttrs(const PartitionOptions& opts, const Edge* edge,
    178                       NodeDefBuilder* builder) {
    179   builder->Attr("tensor_name",
    180                 strings::StrCat("edge_", edge->id(), "_", edge->src()->name()));
    181   builder->Attr("send_device", edge->src()->assigned_device_name());
    182   builder->Attr("send_device_incarnation",
    183                 static_cast<int64>(
    184                     opts.get_incarnation(edge->src()->assigned_device_name())));
    185   builder->Attr("recv_device", edge->dst()->assigned_device_name());
    186   builder->Attr("client_terminated", false);
    187 }
    188 
    189 NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
    190                  GraphDef* gdef, const Edge* edge,
    191                  NodeDefBuilder::NodeOut send_from, int64 start_time,
    192                  Status* status) {
    193   const DataType dtype = send_from.data_type;
    194   const DataType cast_dtype = opts.should_cast ? opts.should_cast(edge) : dtype;
    195   const Node* src = edge->src();
    196   const int src_port = edge->src_output();
    197 
    198   // host_memory = true iff we need to use HostSend/HostCast.
    199   bool host_memory = false;
    200   if (!edge->IsControlEdge()) {
    201     auto src_it = g_info.output_types.find({src->id(), src_port});
    202     DCHECK(src_it != g_info.output_types.end());
    203     host_memory = (src_it->second == HOST_MEMORY);
    204   }
    205 
    206   // Add a cast node that casts dtype to cast_dtype.
    207   // NOTE(yuanbyu): Only cast for cross-device send/recv.
    208   if (dtype != cast_dtype && !NeedSameDeviceSendRecv(edge, g_info)) {
    209     const string cast_op = (host_memory) ? "_HostCast" : "Cast";
    210     NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op);
    211     cast_builder.Device(src->assigned_device_name()).Input(send_from);
    212     if (opts.scheduling_for_recvs) {
    213       cast_builder.Attr("_start_time", start_time);
    214     }
    215     cast_builder.Attr("DstT", cast_dtype);
    216     NodeDef* cast = gdef->add_node();
    217     *status = cast_builder.Finalize(cast);
    218     if (!status->ok()) return nullptr;
    219 
    220     // Connect the Send op to the cast.
    221     send_from.Reset(cast->name(), 0, cast_dtype);
    222   }
    223 
    224   // Add the send node.
    225   const string send_op = (host_memory) ? "_HostSend" : "_Send";
    226   NodeDefBuilder send_builder(opts.new_name(src->name()), send_op);
    227   SetSendRecvAttrs(opts, edge, &send_builder);
    228   send_builder.Device(src->assigned_device_name()).Input(send_from);
    229   if (opts.scheduling_for_recvs) {
    230     send_builder.Attr("_start_time", start_time);
    231   }
    232   NodeDef* send = gdef->add_node();
    233   *status = send_builder.Finalize(send);
    234   return send;
    235 }
    236 
    237 NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info,
    238                  GraphDef* gdef, const Edge* edge, NodeDef** real_recv,
    239                  Status* status) {
    240   const DataType dtype = EdgeType(edge);
    241   const Node* src = edge->src();
    242   const Node* dst = edge->dst();
    243   const int dst_port = edge->dst_input();
    244   DataType cast_dtype = dtype;
    245 
    246   // NOTE(yuanbyu): Only cast for cross-device send/recv.
    247   if (opts.should_cast && !NeedSameDeviceSendRecv(edge, g_info)) {
    248     cast_dtype = opts.should_cast(edge);
    249   }
    250 
    251   // host_memory = true iff we need to use HostRecv/HostCast.
    252   bool host_memory = false;
    253   if (!edge->IsControlEdge()) {
    254     auto dst_it = g_info.input_types.find({dst->id(), dst_port});
    255     DCHECK(dst_it != g_info.input_types.end());
    256     host_memory = (dst_it->second == HOST_MEMORY);
    257   }
    258 
    259   // Add the recv node.
    260   const string recv_op = (host_memory) ? "_HostRecv" : "_Recv";
    261   NodeDefBuilder recv_builder(opts.new_name(src->name()), recv_op);
    262   SetSendRecvAttrs(opts, edge, &recv_builder);
    263   recv_builder.Device(dst->assigned_device_name())
    264       .Attr("tensor_type", cast_dtype);
    265   NodeDef* recv = gdef->add_node();
    266   *status = recv_builder.Finalize(recv);
    267   if (!status->ok()) return nullptr;
    268   *real_recv = recv;
    269 
    270   // Add the cast node (from cast_dtype to dtype) or an Identity node.
    271   if (dtype != cast_dtype) {
    272     const string cast_op = (host_memory) ? "_HostCast" : "Cast";
    273     NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op);
    274     cast_builder.Attr("DstT", dtype);
    275     cast_builder.Device(dst->assigned_device_name())
    276         .Input(recv->name(), 0, cast_dtype);
    277     NodeDef* cast = gdef->add_node();
    278     *status = cast_builder.Finalize(cast);
    279     if (!status->ok()) return nullptr;
    280     return cast;
    281   } else if (edge->IsControlEdge()) {
    282     // An Identity is only needed for control edges.
    283     NodeDefBuilder id_builder(opts.new_name(src->name()), "Identity");
    284     id_builder.Device(dst->assigned_device_name())
    285         .Input(recv->name(), 0, cast_dtype);
    286     NodeDef* id = gdef->add_node();
    287     *status = id_builder.Finalize(id);
    288     if (!status->ok()) return nullptr;
    289     return id;
    290   } else {
    291     return recv;
    292   }
    293 }
    294 
    295 NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef,
    296                        const Edge* edge, Status* status) {
    297   const Node* src = edge->src();
    298   Tensor tensor(DT_FLOAT, TensorShape({0}));
    299   NodeDef* result = gdef->add_node();
    300   *status = NodeDefBuilder(opts.new_name(src->name()), "Const")
    301                 .Device(src->assigned_device_name())
    302                 .Attr("dtype", DT_FLOAT)
    303                 .Attr("value", tensor)
    304                 .Finalize(result);
    305   return result;
    306 }
    307 
    308 // A dummy node for scheduling.
    309 NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef,
    310                            const string& assigned_device_name, int64 epoch,
    311                            int64 starttime, Status* status) {
    312   NodeDef* result = gdef->add_node();
    313   *status = NodeDefBuilder(opts.new_name(strings::StrCat("synch_", epoch)),
    314                            "ControlTrigger")
    315                 .Device(assigned_device_name)
    316                 .Attr("_start_time", starttime)
    317                 .Finalize(result);
    318   return result;
    319 }
    320 
    321 // Optimize colocation for control flow nodes. For cond, we want the
    322 // switch nodes to colocate with its data input. This is particularly
    323 // needed for conditional reading of a remote variable. It may also
    324 // reduce the number of devices involved in a loop.
    325 // TODO(yuanbyu): In this case, we don't respect the requested device in
    326 // the GraphDef for these nodes. Ideally, the placer would enforce the
    327 // colocation to render this unnecessary.
    328 void OptimizeControlFlowColocation(Graph* graph) {
    329   auto visit = [](Node* node) {
    330     if (IsSwitch(node)) {
    331       for (const Edge* in_edge : node->in_edges()) {
    332         if (in_edge->dst_input() == 0) {
    333           // Colocate with the data input.
    334           node->set_assigned_device_name(
    335               in_edge->src()->assigned_device_name());
    336           return;
    337         }
    338       }
    339     } else if (IsExit(node)) {
    340       for (const Edge* in_edge : node->in_edges()) {
    341         if (!in_edge->IsControlEdge()) {
    342           // Colocate with upstream node.
    343           node->set_assigned_device_name(
    344               in_edge->src()->assigned_device_name());
    345           return;
    346         }
    347       }
    348     } else {
    349       if ((IsEnter(node) && !IsRefType(node->input_type(0))) ||
    350           IsNextIteration(node)) {
    351         const Edge* data_edge = nullptr;
    352         for (const Edge* out_edge : node->out_edges()) {
    353           if (!out_edge->IsControlEdge()) {
    354             data_edge = out_edge;
    355             break;
    356           }
    357         }
    358         // Colocate with the first downstream data node.
    359         if (data_edge) {
    360           node->set_assigned_device_name(
    361               data_edge->dst()->assigned_device_name());
    362         }
    363       }
    364     }
    365   };
    366   DFS(*graph, visit, {});
    367 }
    368 
    369 string ControlLoopName(const string& name) {
    370   return strings::StrCat("_cloop", name);
    371 }
    372 
    373 bool IsControlLoop(const Node* node) {
    374   const string& name = node->name();
    375   return StringPiece(name).starts_with("_cloop");
    376 }
    377 
    378 // An enter node for control flow.
    379 Node* AddControlEnter(Graph* g, const string& node_name,
    380                       const string& device_name, const string& frame_name,
    381                       const int parallel_iterations, Status* status) {
    382   NodeBuilder node_builder(node_name, "Enter", g->op_registry());
    383   node_builder.Input({"dummy", 0, DT_FLOAT});
    384   node_builder.Attr("frame_name", frame_name);
    385   node_builder.Attr("parallel_iterations", parallel_iterations);
    386   Node* res_node;
    387   *status = node_builder.Finalize(g, &res_node);
    388   if (!status->ok()) return nullptr;
    389   res_node->set_assigned_device_name(device_name);
    390   return res_node;
    391 }
    392 
    393 // A merge node for control flow.
    394 Node* AddControlMerge(const string& in_name1, const string& in_name2, Graph* g,
    395                       const string& node_name, const string& device_name,
    396                       Status* status) {
    397   NodeBuilder node_builder(node_name, "Merge", g->op_registry());
    398   node_builder.Input({{in_name1, 0, DT_FLOAT}, {in_name2, 0, DT_FLOAT}});
    399   Node* res_node;
    400   *status = node_builder.Finalize(g, &res_node);
    401   if (!status->ok()) return nullptr;
    402   res_node->set_assigned_device_name(device_name);
    403   return res_node;
    404 }
    405 
    406 // A switch node for control flow.
    407 Node* AddControlSwitch(NodeBuilder::NodeOut input1, NodeBuilder::NodeOut input2,
    408                        const string& device_name,
    409                        const GraphDefBuilder::Options& bopts) {
    410   Node* res_node =
    411       ops::BinaryOp("Switch", std::move(input1), std::move(input2), bopts);
    412   if (bopts.HaveError()) return nullptr;
    413   res_node->set_assigned_device_name(device_name);
    414   return res_node;
    415 }
    416 
    417 // A next_iteration node for control flow.
    418 Node* AddControlNext(NodeBuilder::NodeOut input, const string& device_name,
    419                      const GraphDefBuilder::Options& bopts) {
    420   Node* res_node = ops::UnaryOp("NextIteration", std::move(input), bopts);
    421   if (bopts.HaveError()) return nullptr;
    422   res_node->set_assigned_device_name(device_name);
    423   return res_node;
    424 }
    425 
    426 Node* EmptyConst(const GraphDefBuilder::Options& options) {
    427   if (options.HaveError()) return nullptr;
    428   NodeBuilder node_builder(options.GetNameForOp("Const"), "Const",
    429                            options.op_registry());
    430   const DataType dt = DataTypeToEnum<float>::v();
    431   TensorProto proto;
    432   proto.set_dtype(dt);
    433   TensorShape empty_shape({0});
    434   empty_shape.AsProto(proto.mutable_tensor_shape());
    435   node_builder.Attr("dtype", dt).Attr("value", proto);
    436   return options.FinalizeBuilder(&node_builder);
    437 }
    438 
    439 // A dummy const node for control flow.
    440 Node* AddControlConst(const string& device_name,
    441                       const GraphDefBuilder::Options& bopts) {
    442   Node* res_node = EmptyConst(bopts);
    443   if (bopts.HaveError()) return nullptr;
    444   res_node->set_assigned_device_name(device_name);
    445   return res_node;
    446 }
    447 
    448 // A synthetic loop, made up of dummy nodes. It performs control-flow actions
    449 // on behalf of a leader on a different device.
    450 struct ControlLoop {
    451   Node* enter = nullptr;
    452   Node* merge = nullptr;
    453   Node* switch_node = nullptr;
    454 };
    455 
    456 // Add the control flow info of a new node added during partitioning.
    457 // The new node has the same control flow info as src.
    458 void AddControlFlowInfo(const Node* node, const Node* src,
    459                         std::vector<ControlFlowInfo>* cf_info) {
    460   int id = node->id();
    461   if (static_cast<size_t>(id) >= cf_info->size()) {
    462     cf_info->resize(id + 1);
    463   }
    464   const ControlFlowInfo& src_info = (*cf_info)[src->id()];
    465   ControlFlowInfo* info = &(*cf_info)[id];
    466   info->frame = src_info.frame;
    467   info->parent_frame = src_info.parent_frame;
    468   info->frame_name = src_info.frame_name;
    469 }
    470 
    471 // Constructs a control loop. Returns a struct containing the newly created
    472 // enter, merge, and switch nodes. The enter and merge nodes are used in the
    473 // recursive construction of control loops for nested frames (loops). The
    474 // switch node will be connected to the LoopCond node. The merge node will
    475 // be connected to all the recvs of the same frame by control edges when
    476 // the actual partitioning happens.
    477 Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src,
    478                       const Edge* edge, Node* loop_cond,
    479                       std::vector<ControlFlowInfo>* cf_info,
    480                       ControlLoop* loop) {
    481   Status status;
    482   GraphDefBuilder::Options bopts(g, &status);
    483   const ControlFlowInfo& src_info = (*cf_info)[src->id()];
    484   const string& device_name = edge->dst()->assigned_device_name();
    485   const string& frame_name = src_info.frame_name;
    486   int parallel_iterations;
    487   status = GetNodeAttr(src_info.frame->attrs(), "parallel_iterations",
    488                        &parallel_iterations);
    489   if (!status.ok()) return status;
    490 
    491   // The names of the nodes to be added.
    492   const string& enter_name =
    493       ControlLoopName(opts.new_name(edge->dst()->name()));
    494   const string& merge_name =
    495       ControlLoopName(opts.new_name(edge->dst()->name()));
    496   const string& switch_name =
    497       ControlLoopName(opts.new_name(edge->dst()->name()));
    498   const string& next_name = ControlLoopName(opts.new_name(edge->dst()->name()));
    499 
    500   // Add the nodes to the graph g.
    501   Node* enter = AddControlEnter(g, enter_name, device_name, frame_name,
    502                                 parallel_iterations, &status);
    503   if (!status.ok()) return status;
    504   Node* merge = AddControlMerge(enter_name, next_name, g, merge_name,
    505                                 device_name, &status);
    506   if (!status.ok()) return status;
    507   Node* switch_node = AddControlSwitch(merge, loop_cond, device_name,
    508                                        bopts.WithName(switch_name));
    509   if (!status.ok()) return status;
    510   Node* next =
    511       AddControlNext({switch_node, 1}, device_name, bopts.WithName(next_name));
    512   if (!status.ok()) return status;
    513 
    514   // Add control flow info for these new nodes:
    515   AddControlFlowInfo(enter, src, cf_info);
    516   AddControlFlowInfo(merge, src, cf_info);
    517   AddControlFlowInfo(switch_node, src, cf_info);
    518   AddControlFlowInfo(next, src, cf_info);
    519 
    520   // Add input edges for the newly created merge node:
    521   g->AddEdge(enter, 0, merge, 0);
    522   g->AddEdge(next, 0, merge, 1);
    523 
    524   loop->enter = enter;
    525   loop->merge = merge;
    526   loop->switch_node = switch_node;
    527   return Status::OK();
    528 }
    529 
    530 // Build memory and device type info for every node in the graph.
    531 // TODO(yuanbyu): It might be simpler if we convert MemoryType to
    532 // DeviceType for the inputs/outputs of each node.
    533 Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) {
    534   MemoryTypeVector input_memory_types;
    535   MemoryTypeVector output_memory_types;
    536 
    537   info->device_types.resize(g.num_node_ids(), DEVICE_CPU);
    538   for (const Node* node : g.op_nodes()) {
    539     DeviceNameUtils::ParsedName parsed;
    540     if (!DeviceNameUtils::ParseFullName(node->assigned_device_name(),
    541                                         &parsed)) {
    542       return errors::Internal("Malformed assigned device '",
    543                               node->assigned_device_name(), "'");
    544     }
    545 
    546     TF_RETURN_IF_ERROR(MemoryTypesForNode(
    547         g.op_registry(), DeviceType(parsed.type), node->def(),
    548         &input_memory_types, &output_memory_types));
    549 
    550     int node_id = node->id();
    551     info->device_types[node_id] = DeviceType(parsed.type);
    552     for (size_t i = 0; i < input_memory_types.size(); ++i) {
    553       info->input_types[{node_id, i}] = input_memory_types[i];
    554     }
    555     for (size_t i = 0; i < output_memory_types.size(); ++i) {
    556       info->output_types[{node_id, i}] = output_memory_types[i];
    557     }
    558   }
    559   return Status::OK();
    560 }
    561 
    562 const Node* InputFrame(const Node* node,
    563                        const std::vector<ControlFlowInfo>& cf_info) {
    564   // An input is in the same frame as the node except for Enter nodes.
    565   // The input of Enter is in the parent frame of the Enter node.
    566   if (!node->IsEnter()) {
    567     return node;
    568   }
    569   return cf_info[node->id()].parent_frame;
    570 }
    571 
    572 const Node* OutputFrame(const Node* node,
    573                         const std::vector<ControlFlowInfo>& cf_info) {
    574   // An output is in the same frame as the node except for Exit nodes.
    575   // The output of Exit is in the parent frame of the Exit node.
    576   if (!node->IsExit()) {
    577     return node;
    578   }
    579   return cf_info[node->id()].parent_frame;
    580 }
    581 
    582 // Each participating device needs to decide a) if there is a next iteration,
    583 // and b) if the loop terminates. We take the approach to encode this control
    584 // flow logic in the dataflow graph. There are at least two possible encodings.
    585 // In a completely decentralized encoding, the participants communicate peer
    586 // to peer. The other encoding uses a frame leader (the participant who owns
    587 // the pivot termination predicate) to broadcast the termination condition to
    588 // all the participants. For now we take the latter because it is simpler.
    589 //
    590 // TODO(yuanbyu): The correctness of this construction is rather subtle. I got
    591 // it wrong many times so it would be nice to write a proof to be sure.
    592 Status AddControlFlow(const PartitionOptions& opts, Graph* g,
    593                       GraphInfo* g_info) {
    594   Status status;
    595   GraphDefBuilder::Options bopts(g, &status);
    596   std::vector<ControlFlowInfo>& cf_info = g_info->cf_info;
    597 
    598   // Build the control flow info for every node.
    599   status = BuildControlFlowInfo(g, &cf_info);
    600   if (!status.ok()) return status;
    601 
    602   OptimizeControlFlowColocation(g);
    603 
    604   // The map from frames to their LoopCond nodes.
    605   std::unordered_map<string, Node*> frame_cond_map;
    606   int num_node_ids = g->num_node_ids();
    607   for (int i = 0; i < num_node_ids; ++i) {
    608     Node* node = g->FindNodeId(i);
    609     if (node == nullptr) continue;
    610 
    611     if (IsLoopCond(node)) {
    612       const string& frame_name = cf_info[node->id()].frame_name;
    613       DCHECK(!frame_name.empty());
    614       frame_cond_map[frame_name] = node;
    615     }
    616   }
    617 
    618   // Add all control loops for cross-device frames.
    619   // A control loop is added only when there is a cross-device edge in a
    620   // non-root frame. Nothing is added if there is no loops. We also don't
    621   // add anything for a frame that is completely local to a device. For
    622   // nested loops, we stack the control loops together by connecting
    623   // the merge of the outer loop to the enter of the inner loop.
    624   //
    625   // A map from <frame_name, device_name> to ControlLoop.
    626   std::unordered_map<string, ControlLoop> control_loops;
    627   int num_edge_ids = g->num_edge_ids();
    628   for (int i = 0; i < num_edge_ids; ++i) {
    629     const Edge* edge = g->FindEdgeId(i);
    630     if (edge == nullptr) continue;
    631 
    632     const Node* src = edge->src();
    633     const Node* dst = edge->dst();
    634     // Skip Sink/Source nodes.
    635     if (!src->IsOp() || !dst->IsOp()) continue;
    636 
    637     const string& src_device = src->assigned_device_name();
    638     const string& dst_device = dst->assigned_device_name();
    639     // Skip local edges.
    640     if (src_device == dst_device) continue;
    641 
    642     const Node* src_frame = OutputFrame(src, cf_info);
    643     const Node* dst_frame = InputFrame(dst, cf_info);
    644     const string& src_frame_name = cf_info[src_frame->id()].frame_name;
    645     const string& dst_frame_name = cf_info[dst_frame->id()].frame_name;
    646     // Skip if src and dst are not in the same frame.
    647     if (src_frame_name.empty() || src_frame_name != dst_frame_name) {
    648       continue;
    649     }
    650 
    651     // Add the control loop. Start by adding the control loop for the
    652     // current frame if needed, and recursively adding the control loop
    653     // for its outer frame when nested.
    654     ControlLoop child_loop;
    655     while (true) {
    656       const string& curr_frame_name = cf_info[src_frame->id()].frame_name;
    657       if (curr_frame_name.empty()) {
    658         // We have reached the root frame.
    659         if (child_loop.merge != nullptr) {
    660           const string& node_name = opts.new_name(edge->dst()->name());
    661           const string& device_name = edge->dst()->assigned_device_name();
    662           Node* const_node =
    663               AddControlConst(device_name, bopts.WithName(node_name));
    664           if (!status.ok()) return status;
    665           AddControlFlowInfo(const_node, src_frame, &cf_info);
    666           g->AddEdge(const_node, 0, child_loop.enter, 0);
    667         }
    668         break;
    669       }
    670 
    671       const string& cl_key = strings::StrCat(curr_frame_name, "$$", dst_device);
    672       auto it = control_loops.find(cl_key);
    673       if (it != control_loops.end()) {
    674         if (child_loop.enter != nullptr) {
    675           g->AddEdge(it->second.merge, 0, child_loop.enter, 0);
    676         }
    677         break;
    678       }
    679 
    680       // Get the frame's LoopCond.
    681       auto cond_it = frame_cond_map.find(curr_frame_name);
    682       if (cond_it == frame_cond_map.end()) {
    683         return errors::InvalidArgument(
    684             "A cross-device loop must have a pivot predicate: ",
    685             curr_frame_name);
    686       }
    687       Node* loop_cond = cond_it->second;
    688 
    689       // Add the control loop.
    690       ControlLoop curr_loop;
    691       status = AddControlLoop(opts, g, src_frame, edge, loop_cond, &cf_info,
    692                               &curr_loop);
    693       if (!status.ok()) return status;
    694       control_loops[cl_key] = curr_loop;
    695 
    696       if (child_loop.enter != nullptr) {
    697         // Connect the merge of the outer loop to the enter of the inner.
    698         g->AddEdge(curr_loop.merge, 0, child_loop.enter, 0);
    699       }
    700       src_frame = cf_info[src_frame->id()].parent_frame;
    701       child_loop = curr_loop;
    702     }
    703   }
    704 
    705   // For a cross-device edge, on the dst device, add a control edge
    706   // from the merge node of the control loop to dst. If a send/recv is
    707   // introduced for this edge in future partitioning, we delete this
    708   // control edge and add a new control edge from the merge to the recv.
    709   num_edge_ids = g->num_edge_ids();
    710   for (int i = 0; i < num_edge_ids; ++i) {
    711     const Edge* edge = g->FindEdgeId(i);
    712     if (edge == nullptr) continue;
    713 
    714     const Node* src = edge->src();
    715     Node* dst = edge->dst();
    716     // Skip Sink/Source nodes.
    717     if (!src->IsOp() || !dst->IsOp()) continue;
    718 
    719     const string& src_device = src->assigned_device_name();
    720     const string& dst_device = dst->assigned_device_name();
    721     if (src_device != dst_device) {
    722       const Node* src_frame = OutputFrame(src, cf_info);
    723       const Node* dst_frame = InputFrame(dst, cf_info);
    724       const string& src_frame_name = cf_info[src_frame->id()].frame_name;
    725       const string& dst_frame_name = cf_info[dst_frame->id()].frame_name;
    726       if (!src_frame_name.empty() && src_frame_name == dst_frame_name) {
    727         const string& cl_key =
    728             strings::StrCat(dst_frame_name, "$$", dst_device);
    729         ControlLoop loop = control_loops[cl_key];
    730         DCHECK(loop.enter != nullptr);
    731         // Note that we'll create multiple duplicate edges if dst has multiple
    732         // cross-device inputs. This is expected by the logic in Partition(), so
    733         // it can add control edges to the recv nodes once they're created.
    734         g->AddControlEdge(loop.merge, dst, /*allow_duplicates=*/true);
    735       }
    736     }
    737   }
    738   return Status::OK();
    739 }
    740 
    741 struct PriorityTopoSortNode {
    742   PriorityTopoSortNode(const NodeDef* n, int64 st) : node(n), start_time(st) {}
    743 
    744   const NodeDef* node;
    745   int64 start_time;
    746 };
    747 
    748 struct PriorityTopoSortNodeGreater {
    749   bool operator()(const PriorityTopoSortNode& left,
    750                   const PriorityTopoSortNode& right) {
    751     return left.start_time > right.start_time;
    752   }
    753 };
    754 
    755 }  // namespace
    756 
    757 // Returns in <nodes> the nodes that should participate in epoch-based recv
    758 // scheduling, along with their times; <nodes> is ordered by increasing
    759 // start_time. Returns in <node_to_start_time_out> the timing for all nodes,
    760 // even those not in <nodes>.
    761 //
    762 // Comparing to sorting on the node's start time only, this also processes the
    763 // nodes in dependency order, and updates start times to ensure a node's
    764 // start_time > the start time for all dependencies.
    765 //
    766 // Note that graph_partition_test.cc accesses this function for testing, even
    767 // though it's not declared in the header.
    768 Status TopologicalSortNodesWithTimePriority(
    769     const GraphDef* gdef, std::vector<std::pair<const NodeDef*, int64>>* nodes,
    770     std::unordered_map<const NodeDef*, int64>* node_to_start_time_out) {
    771   // Queue of nodes to process; lowest start time is returned first.
    772   std::priority_queue<PriorityTopoSortNode, std::vector<PriorityTopoSortNode>,
    773                       PriorityTopoSortNodeGreater>
    774       q;
    775   std::unordered_map<const NodeDef*, int64> node_to_start_time;
    776   auto enqueue = [&q, &node_to_start_time](const NodeDef* node) {
    777     const int64 start_time = node_to_start_time[node];
    778     q.emplace(node, start_time);
    779   };
    780 
    781   // Build initial structures, initial contents of queue.
    782   std::unordered_map<string, std::vector<const NodeDef*>> node_to_output_nodes;
    783   std::unordered_map<const NodeDef*, int> inputs_needed;
    784   for (int n = 0; n < gdef->node_size(); ++n) {
    785     const NodeDef* ndef = &gdef->node(n);
    786     for (int i = 0; i < ndef->input_size(); ++i) {
    787       node_to_output_nodes[ParseTensorName(ndef->input(i)).first.ToString()]
    788           .push_back(ndef);
    789     }
    790     int64 start_time;
    791     TF_RETURN_IF_ERROR(GetNodeAttr(*ndef, "_start_time", &start_time));
    792     node_to_start_time[ndef] = start_time;
    793     inputs_needed[ndef] = ndef->input_size();
    794     if (ndef->input_size() == 0) {
    795       enqueue(ndef);
    796     }
    797   }
    798 
    799   // Determine which merge nodes are parts of loops; these
    800   // need to happen in the traversal after all non-NextIteration inputs
    801   // are run.
    802   for (int n = 0; n < gdef->node_size(); ++n) {
    803     const NodeDef* ndef = &gdef->node(n);
    804     if (IsNextIteration(*ndef)) {
    805       for (const NodeDef* n : node_to_output_nodes[ndef->name()]) {
    806         if (IsMerge(*n)) {
    807           // n is a merge that is part of a loop structure.
    808           // It doesn't need to wait for this NextIteration loop
    809           // when doing the traversal.
    810           --inputs_needed[n];
    811         }
    812       }
    813     }
    814   }
    815 
    816   // Traverse.
    817   std::vector<std::pair<const NodeDef*, int64>> start_times;
    818   start_times.reserve(gdef->node_size());
    819   while (!q.empty()) {
    820     PriorityTopoSortNode cur = q.top();
    821     q.pop();
    822 
    823     start_times.emplace_back(cur.node, cur.start_time);
    824 
    825     for (const NodeDef* n : node_to_output_nodes[cur.node->name()]) {
    826       auto& output_start_time = node_to_start_time[n];
    827       if (output_start_time <= cur.start_time) {
    828         output_start_time = cur.start_time + 1;
    829       }
    830       if (--inputs_needed[n] == 0) {
    831         enqueue(n);
    832       }
    833     }
    834   }
    835 
    836   // Done.
    837   nodes->swap(start_times);
    838   node_to_start_time_out->swap(node_to_start_time);
    839   return Status::OK();
    840 }
    841 
    842 Status AddControlEdges(const PartitionOptions& opts,
    843                        std::unordered_map<string, GraphDef>* partitions) {
    844   Status status;
    845   // TODO(yuanbyu): Very naive for now. To be improved.
    846   const int num_epochs = 100;
    847   const int prefetch = 6;
    848 
    849   for (auto& part : *partitions) {
    850     GraphDef* gdef = &part.second;
    851     std::vector<std::pair<const NodeDef*, int64>> start_times;
    852     std::unordered_map<const NodeDef*, int64> node_to_start_time;
    853     status = TopologicalSortNodesWithTimePriority(gdef, &start_times,
    854                                                   &node_to_start_time);
    855     if (!status.ok()) {
    856       return status;
    857     }
    858 
    859     // Add a dummy node for every epoch, and add a control edge from the
    860     // "last" node in the preceding epoch to the dummy node.
    861     string device_name = gdef->node(0).device();
    862     int64 makespan = start_times.back().second;
    863     int64 resolution = (makespan / num_epochs) + 1;
    864 
    865     int i = 0;
    866     int j = 0;
    867     std::vector<NodeDef*> dummys;
    868     while (i < num_epochs && static_cast<size_t>(j) < start_times.size()) {
    869       if (i * resolution > start_times[j].second) {
    870         j++;
    871       } else {
    872         NodeDef* dummy = AddControlTrigger(opts, gdef, device_name, i,
    873                                            i * resolution, &status);
    874         if (!status.ok()) {
    875           return status;
    876         }
    877         dummys.push_back(dummy);
    878         if (j > 0) {
    879           string src_name = start_times[j - 1].first->name();
    880           AddInput(dummy, src_name, Graph::kControlSlot);
    881         }
    882         i++;
    883       }
    884     }
    885 
    886     // Finally, add the control edges to recvs.
    887     for (int n = 0; n < gdef->node_size(); ++n) {
    888       NodeDef* ndef = gdef->mutable_node(n);
    889       if (ndef->op() == "_Recv") {
    890         const int64 start_time = node_to_start_time[ndef];
    891         const int recv_epoch = start_time / resolution;
    892         if (recv_epoch >= prefetch) {
    893           NodeDef* dummy = dummys[recv_epoch - prefetch];
    894           AddInput(ndef, dummy->name(), Graph::kControlSlot);
    895         }
    896       }
    897     }
    898   }
    899   return Status::OK();
    900 }
    901 
    902 // If 'ndef' is a Send or Recv, fills its attr send_device_incarnation
    903 // if possible.
    904 void SetIncarnation(const PartitionOptions& opts, NodeDef* ndef) {
    905   StringPiece op(ndef->op());
    906   if (op != "_Send" && op != "_Recv") {
    907     // Not related to send/recv.
    908     return;
    909   }
    910   string send_device;
    911   if (!GetNodeAttr(*ndef, "send_device", &send_device).ok()) {
    912     // No known send_device. The runtime will detect it later.
    913     return;
    914   }
    915   int64 incarnation = PartitionOptions::kIllegalIncarnation;
    916   if (!GetNodeAttr(*ndef, "send_device_incarnation", &incarnation).ok() ||
    917       (incarnation == PartitionOptions::kIllegalIncarnation)) {
    918     incarnation = opts.get_incarnation(send_device);
    919     SetAttrValue(incarnation,
    920                  &((*ndef->mutable_attr())["send_device_incarnation"]));
    921   }
    922 }
    923 
    924 // Sets attribute send_device_incarnation of all Send/Recv nodes in
    925 // 'gdef', if possible.
    926 void SetIncarnation(const PartitionOptions& opts, GraphDef* gdef) {
    927   for (NodeDef& ndef : *gdef->mutable_node()) {
    928     SetIncarnation(opts, &ndef);
    929   }
    930   for (FunctionDef& fdef : *gdef->mutable_library()->mutable_function()) {
    931     for (NodeDef& ndef : *fdef.mutable_node_def()) {
    932       SetIncarnation(opts, &ndef);
    933     }
    934   }
    935 }
    936 
    // Splits graph 'g' into one GraphDef per partition name returned by
    // opts.node_to_loc() and stores them in '*partitions', which is cleared
    // first.
    //
    // Every op node is copied into its partition. An input edge is kept as a
    // plain input when both ends are in the same partition and
    // NeedSameDeviceSendRecv() is false; otherwise it is replaced by a send/recv
    // pair created with AddSend()/AddRecv(), reusing an earlier recv for the
    // same source when one is recorded in 'dup_recv'.
    //
    // Unless opts.control_flow_added is set, AddControlFlow() is first run on
    // 'g', so 'g' itself may be modified. Afterwards each partition receives the
    // graph versions, the function library and send/recv incarnations.
    //
    // Returns the first non-OK status encountered; InvalidArgument is returned
    // when a node has fewer data in-edges than dst->num_inputs().
    937 Status Partition(const PartitionOptions& opts, Graph* g,
    938                  std::unordered_map<string, GraphDef>* partitions) {
    939   Status status;
    940   partitions->clear();
    941 
    942   GraphInfo g_info;
    943   if (!opts.control_flow_added) {
    944     // Add the "code" for distributed execution of control flow. Code is
    945     // added only for the frames that are placed on multiple devices. The
    946     // new graph is an equivalent transformation of the original graph and
    947     // has the property that it can be subsequently partitioned arbitrarily
    948     // (down to the level of individual device) for distributed execution.
    949     status = AddControlFlow(opts, g, &g_info);
    950     if (!status.ok()) return status;
    951   }
    952 
    953   // At this point, all the graph mutations have been done. Build memory
    954   // and device type info for every node and edge in the graph.
    955   status = BuildMemoryDeviceInfo(*g, &g_info);
    956   if (!status.ok()) return status;
    957 
    958   string dstp;
    959   std::vector<const Edge*> inputs;
          // Remembers, per (src node id, src output, dst partition, on_host) key,
          // the recv already created, so later edges carrying the same value can
          // reuse it instead of adding another send/recv pair.
          // NOTE(review): the constructor argument 3 is presumably an initial
          // bucket count -- confirm against the DupRecvTable definition.
    960   DupRecvTable dup_recv(3);
    961   // For a node dst, 'ref_recvs' remembers the recvs introduced by a ref
    962   // edge to dst. 'ref_control_inputs' remembers the inputs by a non-ref
    963   // edge to dst. We will add a control edge for every pair in
    964   // (ref_recvs x ref_control_inputs).
    965   std::vector<NodeDef*> ref_recvs;
    966   std::vector<string> ref_control_inputs;
    967 
          // Number of send/recv pairs created for data and for control edges;
          // only reported through the VLOG at the end of the function.
    968   int32 num_data = 0;
    969   int32 num_control = 0;
    970   for (const Node* dst : g->op_nodes()) {
    971     dstp = opts.node_to_loc(dst);
    972     GraphDef* dst_graph = &(*partitions)[dstp];
    973     NodeDef* dst_def = dst_graph->add_node();
    974     *dst_def = dst->def();
    975     dst_def->set_device(dst->assigned_device_name());
    976     dst_def->clear_input();  // Inputs are filled below
    977     if (opts.need_to_record_start_times) {
    978       int64 start_time;
    979       status = GetNodeAttr(*dst_def, "_start_time", &start_time);
    980       if (errors::IsNotFound(status)) {
    981         start_time = opts.start_times[dst->id()].value();
    982         AddNodeAttr("_start_time", start_time, dst_def);
    983       } else if (!status.ok()) {
    984         return status;
    985       }
    986     }
    987 
    988     // Arrange the incoming edges to dst so that input[i] holds the
    989     // input flowing into slot numbered i. Trailing entries in input[]
    990     // hold control edges.
    991     inputs.clear();
    992     inputs.resize(dst->num_inputs(), nullptr);
    993     ref_recvs.clear();
    994     ref_control_inputs.clear();
            // The most recently seen control edge whose source is a control-loop
            // Merge, or nullptr if dst has none.
    995     const Edge* control_flow_edge = nullptr;
    996     int32 num_control_flow_edges = 0;
    997     int32 num_input_edges = 0;
    998     for (const Edge* edge : dst->in_edges()) {
    999       if (edge->IsControlEdge()) {
   1000         if (IsMerge(edge->src()) && IsControlLoop(edge->src())) {
   1001           // This is one of the control edges added for control flow. There
   1002           // can be multiple such edges as the dest node may have multiple
   1003           // remote inputs. We keep track of the number of such edges.
   1004           control_flow_edge = edge;
   1005           ++num_control_flow_edges;
   1006         } else {
   1007           inputs.push_back(edge);
   1008         }
   1009       } else {
   1010         DCHECK(inputs[edge->dst_input()] == nullptr);
   1011         inputs[edge->dst_input()] = edge;
   1012         ++num_input_edges;
   1013       }
   1014     }
   1015 
   1016     if (num_input_edges != dst->num_inputs()) {
   1017       return errors::InvalidArgument("Incomplete graph, missing ",
   1018                                      (dst->num_inputs() - num_input_edges),
   1019                                      " inputs for ", dst->name());
   1020     }
   1021 
   1022     // Process in order so that all data edges are added as inputs to
   1023     // dst in Edge::dst_input() order.
   1024     for (const Edge* edge : inputs) {
   1025       const Node* src = edge->src();
   1026       if (!src->IsOp()) continue;  // Skip Sink/Source nodes.
   1027 
   1028       GraphDef* src_graph = &(*partitions)[opts.node_to_loc(src)];
   1029       if (src_graph == dst_graph && !NeedSameDeviceSendRecv(edge, g_info)) {
   1030         // Same partition and compatible memory types:
   1031         AddInput(dst_def, src->name(), edge->src_output());
   1032         if (edge->IsControlEdge() ||
   1033             !IsRefType(src->output_type(edge->src_output()))) {
   1034           ref_control_inputs.push_back(src->name());
   1035         }
   1036         continue;
   1037       }
   1038 
              // Both start times stay 0 unless opts.scheduling_for_recvs is set.
   1039       int64 send_start_time = 0;
   1040       int64 recv_start_time = 0;
   1041       if (opts.scheduling_for_recvs) {
   1042         status = GetNodeAttr(src->attrs(), "_start_time", &send_start_time);
   1043         if (errors::IsNotFound(status) && opts.need_to_record_start_times) {
   1044           send_start_time = opts.start_times[src->id()].value();
   1045         } else if (!status.ok()) {
   1046           return status;
   1047         }
   1048 
   1049         status = GetNodeAttr(dst->attrs(), "_start_time", &recv_start_time);
   1050         if (errors::IsNotFound(status) && opts.need_to_record_start_times) {
   1051           recv_start_time = opts.start_times[dst->id()].value();
   1052         } else if (!status.ok()) {
   1053           return status;
   1054         }
   1055       }
   1056 
   1057       // Check whether there is already a send/recv pair transferring
   1058       // the same tensor/control from the src to dst partition.
   1059       const bool on_host = IsDstInputOnHost(edge, g_info);
   1060       DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host};
   1061       auto iter = dup_recv.find(key);
   1062       if (iter != dup_recv.end()) {
   1063         // We found one. Reuse the data/control transferred already.
   1064         const string& recv_node_name = iter->second.recv->name();
   1065         if (edge->IsControlEdge()) {
   1066           AddInput(dst_def, recv_node_name, Graph::kControlSlot);
   1067         } else {
   1068           AddInput(dst_def, recv_node_name, 0);
   1069         }
   1070         ref_control_inputs.push_back(recv_node_name);
   1071 
   1072         // We want the start_time for the recv to be the smallest of the start
   1073         // times of its consumers. So we update this whenever we use a recv,
   1074         // and write it out to the attribute at the end of the subroutine.
   1075         if (iter->second.start_time > recv_start_time) {
   1076           iter->second.start_time = recv_start_time;
   1077         }
   1078         continue;
   1079       }
   1080 
   1081       NodeDefBuilder::NodeOut send_from;
   1082       if (edge->IsControlEdge()) {
   1083         // Insert a dummy const node that will generate a tiny
   1084         // data element to be sent from send to recv.
   1085         VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "["
   1086                 << src->name() << "] -> " << dst->assigned_device_name() << "["
   1087                 << dst->name() << "]";
   1088         NodeDef* dummy = AddDummyConst(opts, src_graph, edge, &status);
   1089         if (!status.ok()) return status;
   1090         // Set the start time for this dummy node.
   1091         if (opts.scheduling_for_recvs) {
   1092           AddNodeAttr("_start_time", send_start_time, dummy);
   1093         }
   1094         AddInput(dummy, src->name(), Graph::kControlSlot);
   1095         send_from.Reset(dummy->name(), 0, DT_FLOAT);
   1096       } else {
   1097         send_from.Reset(src->name(), edge->src_output(), EdgeType(edge));
   1098       }
   1099 
   1100       // Need to split edge by placing matching send/recv nodes on
   1101       // the src/dst sides of the edge.
   1102       NodeDef* send = AddSend(opts, g_info, src_graph, edge, send_from,
   1103                               send_start_time, &status);
   1104       if (!status.ok()) return status;
   1105 
   1106       NodeDef* real_recv = nullptr;
   1107       NodeDef* recv =
   1108           AddRecv(opts, g_info, dst_graph, edge, &real_recv, &status);
   1109       if (!status.ok()) return status;
   1110 
   1111       // Fix up the control flow edge.
   1112       // NOTE(yuanbyu): 'real_recv' must be the real recv node.
   1113       if (src_graph == dst_graph) {
   1114         // For same device send/recv, add a control edge from send to recv.
   1115         // This prevents the asynchronous recv kernel from being scheduled
   1116         // before the data is available.
   1117         AddInput(real_recv, send->name(), Graph::kControlSlot);
   1118       } else if (control_flow_edge != nullptr) {
   1119         // Redirect control edge to the real recv since this is not the same
   1120         // device send/recv.
                // The decrement leaves one fewer control-flow edge to add back
                // onto dst after this loop.
   1121         --num_control_flow_edges;
   1122         AddInput(real_recv, control_flow_edge->src()->name(),
   1123                  Graph::kControlSlot);
   1124       }
   1125 
   1126       if (!edge->IsControlEdge() &&
   1127           IsRefType(src->output_type(edge->src_output()))) {
   1128         AddNodeAttr("_start_time", recv_start_time, recv);
   1129         if (real_recv != recv) {
   1130           AddNodeAttr("_start_time", recv_start_time, real_recv);
   1131         }
   1132         // If src is of ref type and the edge is not a control edge, dst has
   1133         // read semantics and therefore we must control the recv.
   1134         ref_recvs.push_back(real_recv);
   1135       } else {
   1136         // Memorize the send/recv pair, only if this is not a "ref" edge.
   1137         // NOTE(yuanbyu): Collapsing ref edges requires extreme care so
   1138         // for now we don't do it.
   1139         dup_recv[key] = {recv, real_recv, recv_start_time};
   1140         ref_control_inputs.push_back(recv->name());
   1141       }
   1142 
   1143       if (edge->IsControlEdge()) {
   1144         ++num_control;
   1145         AddInput(dst_def, recv->name(), Graph::kControlSlot);
   1146       } else {
   1147         ++num_data;
   1148         AddInput(dst_def, recv->name(), 0);
   1149       }
   1150     }
   1151 
   1152     // Add control edges from 'ref_control_inputs' to 'ref_recvs'.
   1153     // NOTE(yuanbyu): Adding these control edges should not introduce
   1154     // deadlocks. 'dst' has implicit "read" nodes that, when we split
   1155     // across devices, are made explicit; Retargeting the dependencies
   1156     // to 'dst' to those nodes would not introduce cycles if there isn't
   1157     // one before the transformation.
   1158     // NOTE(yuanbyu): This may impact performance because it defers the
   1159     // execution of recvs until all the other inputs become available.
   1160     AddReadControl(ref_recvs, ref_control_inputs);
   1161 
   1162     // Add back the control edges for control flow that are not used.
   1163     if (control_flow_edge != nullptr) {
   1164       for (int i = 0; i < num_control_flow_edges; ++i) {
   1165         AddInput(dst_def, control_flow_edge->src()->name(),
   1166                  Graph::kControlSlot);
   1167       }
   1168     }
   1169   }
   1170 
          // Fall back to the graph's own function library when the options do not
          // supply one.
   1171   const FunctionLibraryDefinition* flib_def = opts.flib_def;
   1172   if (flib_def == nullptr) {
   1173     flib_def = &g->flib_def();
   1174   }
   1175 
   1176   // Set versions, function library and send/recv incarnation.
   1177   for (auto& it : *partitions) {
   1178     GraphDef* gdef = &it.second;
   1179     *gdef->mutable_versions() = g->versions();
   1180     *gdef->mutable_library() = flib_def->ToProto();
   1181 
   1182     // Traverse the graph to fill every send/recv op's incarnation
   1183     // information.
   1184     SetIncarnation(opts, gdef);
   1185   }
   1186 
   1187   // Set the start times for recvs at the very end.
   1188   if (opts.scheduling_for_recvs) {
   1189     for (auto& it : dup_recv) {
   1190       AddNodeAttr("_start_time", it.second.start_time, it.second.recv);
   1191       if (it.second.real_recv != it.second.recv) {
   1192         AddNodeAttr("_start_time", it.second.start_time, it.second.real_recv);
   1193       }
   1194     }
   1195   }
   1196 
   1197   VLOG(1) << "Added send/recv: controls=" << num_control
   1198           << ", data=" << num_data;
   1199   return Status::OK();
   1200 }
   1201 
   1202 }  // namespace tensorflow
   1203