1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/graph/graph_partition.h" 17 18 #include <deque> 19 #include <queue> 20 #include <unordered_map> 21 #include <unordered_set> 22 #include <utility> 23 #include <vector> 24 25 #include "absl/container/flat_hash_map.h" 26 #include "tensorflow/core/framework/function.h" 27 #include "tensorflow/core/framework/memory_types.h" 28 #include "tensorflow/core/framework/node_def_builder.h" 29 #include "tensorflow/core/framework/tensor.pb.h" 30 #include "tensorflow/core/framework/types.h" 31 #include "tensorflow/core/framework/versions.pb.h" 32 #include "tensorflow/core/graph/algorithm.h" 33 #include "tensorflow/core/graph/control_flow.h" 34 #include "tensorflow/core/graph/costmodel.h" 35 #include "tensorflow/core/graph/graph_def_builder.h" 36 #include "tensorflow/core/graph/node_builder.h" 37 #include "tensorflow/core/graph/tensor_id.h" 38 #include "tensorflow/core/lib/core/errors.h" 39 #include "tensorflow/core/lib/hash/hash.h" 40 #include "tensorflow/core/lib/strings/str_util.h" 41 #include "tensorflow/core/platform/logging.h" 42 #include "tensorflow/core/util/device_name_utils.h" 43 44 namespace tensorflow { 45 46 namespace { 47 48 inline bool IsMerge(const NodeDef& node_def) { 49 return node_def.op() == "Merge" || node_def.op() == "RefMerge"; 50 } 51 52 inline 
bool IsNextIteration(const NodeDef& node_def) { 53 return node_def.op() == "NextIteration" || 54 node_def.op() == "RefNextIteration"; 55 } 56 57 struct DupRecvKey { 58 int src_node_id; // Edge's src node id 59 int src_output_slot; // Edge's src node output slot 60 GraphDef* dst_graph; // Edge's dst node is in this subgraph 61 bool recv_output_on_host; // The output of recv is on host 62 63 template <typename H> 64 friend H AbslHashValue(H h, const DupRecvKey& c) { 65 return H::combine(std::move(h), c.src_node_id, c.src_output_slot, 66 reinterpret_cast<std::uintptr_t>(c.dst_graph), 67 c.recv_output_on_host); 68 } 69 70 friend bool operator==(const DupRecvKey& x, const DupRecvKey& y) { 71 return (x.src_node_id == y.src_node_id) && 72 (x.src_output_slot == y.src_output_slot) && 73 (x.dst_graph == y.dst_graph) && 74 (x.recv_output_on_host == y.recv_output_on_host); 75 } 76 }; 77 78 // struct used to store the recvs, so that start times can be properly updated 79 struct RecvInfo { 80 NodeDef* recv; 81 NodeDef* real_recv; 82 int64 start_time; 83 }; 84 85 typedef absl::flat_hash_map<DupRecvKey, RecvInfo> DupRecvTable; 86 87 // A map used to store memory types for the inputs/outputs of every node. 88 // The key is a pair of ints consisting of a node id and input/output index. 89 // TODO(power): migrate back to std::pair when absl::Hash is fixed for MSVC. 90 struct NodePort { 91 int node_id; 92 int index; 93 94 friend bool operator==(const NodePort& x, const NodePort& y) { 95 return x.node_id == y.node_id && x.index == y.index; 96 } 97 98 template <typename H> 99 friend H AbslHashValue(H h, const NodePort& c) { 100 return H::combine(std::move(h), c.node_id, c.index); 101 } 102 }; 103 104 typedef absl::flat_hash_map<NodePort, MemoryType> MemoryTypeMap; 105 106 // We collect the following information about the graph before performing 107 // graph partitioning. 
108 struct GraphInfo { 109 std::vector<DeviceType> device_types; 110 MemoryTypeMap input_types; 111 MemoryTypeMap output_types; 112 std::vector<ControlFlowInfo> cf_info; 113 }; 114 115 DataType EdgeType(const Edge* e) { 116 if (e->IsControlEdge()) { 117 return DT_FLOAT; 118 } else { 119 return e->dst()->input_type(e->dst_input()); 120 } 121 } 122 123 // Return true iff we need to add the same device send/recv for 'edge'. 124 bool NeedSameDeviceSendRecv(const Edge* edge, const GraphInfo& info) { 125 if (edge->IsControlEdge()) { 126 return false; 127 } 128 129 const Node* src = edge->src(); 130 const Node* dst = edge->dst(); 131 if (src->assigned_device_name() == dst->assigned_device_name()) { 132 int src_port = edge->src_output(); 133 int dst_port = edge->dst_input(); 134 if (info.device_types[src->id()] != DEVICE_CPU) { 135 auto src_it = info.output_types.find({src->id(), src_port}); 136 DCHECK(src_it != info.output_types.end()); 137 auto dst_it = info.input_types.find({dst->id(), dst_port}); 138 DCHECK(dst_it != info.input_types.end()); 139 return src_it->second != dst_it->second; 140 } 141 } 142 return false; 143 } 144 145 // Return true iff (dst, dst_input) is specified on host memory. 146 bool IsDstInputOnHost(const Edge* edge, const GraphInfo& info) { 147 const Node* dst = edge->dst(); 148 int dst_port = edge->dst_input(); 149 if (info.device_types[dst->id()] != DEVICE_CPU) { 150 if (edge->IsControlEdge()) return false; 151 auto dst_it = info.input_types.find({dst->id(), dst_port}); 152 DCHECK(dst_it != info.input_types.end()); 153 return dst_it->second == HOST_MEMORY; 154 } 155 return true; 156 } 157 158 // Add an input to dst that comes from the "src_slot" output of the 159 // node named by "src_name". 
void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
  // Input string forms: "^name" for a control input, bare "name" for output
  // slot 0, and "name:slot" for any other output slot.
  if (src_slot == Graph::kControlSlot) {
    dst->add_input(strings::StrCat("^", src_name));
  } else if (src_slot == 0) {
    dst->add_input(src_name.data(), src_name.size());
  } else {
    dst->add_input(strings::StrCat(src_name, ":", src_slot));
  }
}

// Add a control edge from each input to each recv.
// Every name in 'inputs' is appended to every recv as a "^name" input.
void AddReadControl(const std::vector<NodeDef*>& recvs,
                    const std::vector<string>& inputs) {
  for (NodeDef* recv : recvs) {
    for (const string& input : inputs) {
      recv->add_input(strings::StrCat("^", input));
    }
  }
}

// Sets the attributes shared by the send and recv nodes created for 'edge':
// "tensor_name" ("edge_<edge id>_<src name>"), "send_device",
// "send_device_incarnation" (from opts.get_incarnation), "recv_device", and
// "client_terminated" (always false).
void SetSendRecvAttrs(const PartitionOptions& opts, const Edge* edge,
                      NodeDefBuilder* builder) {
  builder->Attr("tensor_name",
                strings::StrCat("edge_", edge->id(), "_", edge->src()->name()));
  builder->Attr("send_device", edge->src()->assigned_device_name());
  builder->Attr("send_device_incarnation",
                static_cast<int64>(
                    opts.get_incarnation(edge->src()->assigned_device_name())));
  builder->Attr("recv_device", edge->dst()->assigned_device_name());
  builder->Attr("client_terminated", false);
}

// Adds to 'gdef' a send node for 'edge' fed by 'send_from', placed on the
// source node's assigned device. When opts.should_cast yields a type that
// differs from send_from's type (and no same-device send/recv is needed), a
// cast node is inserted in front of the send.
//
// Returns the send NodeDef, with *status holding the result of finalizing
// it; the pointer is returned even when that finalization fails. Returns
// nullptr if building the cast node fails.
NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
                 GraphDef* gdef, const Edge* edge,
                 NodeDefBuilder::NodeOut send_from, int64 start_time,
                 Status* status) {
  const DataType dtype = send_from.data_type;
  const DataType cast_dtype = opts.should_cast ? opts.should_cast(edge) : dtype;
  const Node* src = edge->src();
  const int src_port = edge->src_output();

  // host_memory = true iff we need to use HostSend/HostCast.
  bool host_memory = false;
  if (!edge->IsControlEdge()) {
    auto src_it = g_info.output_types.find({src->id(), src_port});
    DCHECK(src_it != g_info.output_types.end());
    host_memory = (src_it->second == HOST_MEMORY);
  }

  // Add a cast node that casts dtype to cast_dtype.
  // NOTE(yuanbyu): Only cast for cross-device send/recv.
  if (dtype != cast_dtype && !NeedSameDeviceSendRecv(edge, g_info)) {
    const string cast_op = (host_memory) ? "_HostCast" : "Cast";
    NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
                                NodeDebugInfo(*src));
    cast_builder.Device(src->assigned_device_name()).Input(send_from);
    if (opts.scheduling_for_recvs) {
      cast_builder.Attr("_start_time", start_time);
    }
    cast_builder.Attr("DstT", cast_dtype);

    if (cast_dtype == DT_BFLOAT16) {
      // the below attribute specifies that the cast to bfloat16 should use
      // truncation. This is needed to retain legacy behavior when we change
      // the default bfloat16 casts to use rounding instead of truncation
      cast_builder.Attr("Truncate", true);
    }

    NodeDef* cast = gdef->add_node();
    *status = cast_builder.Finalize(cast);
    if (!status->ok()) return nullptr;

    // Connect the Send op to the cast.
    send_from.Reset(cast->name(), 0, cast_dtype);
  }

  // Add the send node.
  const string send_op = (host_memory) ? "_HostSend" : "_Send";
  NodeDefBuilder send_builder(opts.new_name(src->name()), send_op,
                              NodeDebugInfo(*src));
  SetSendRecvAttrs(opts, edge, &send_builder);
  send_builder.Device(src->assigned_device_name()).Input(send_from);
  if (opts.scheduling_for_recvs) {
    send_builder.Attr("_start_time", start_time);
  }
  NodeDef* send = gdef->add_node();
  *status = send_builder.Finalize(send);
  return send;
}

// Adds to 'gdef' a recv node for 'edge', placed on the destination node's
// assigned device, and stores it in *real_recv.
//
// Returns the node that consumers should read from:
//   - a cast node (cast_dtype -> dtype) when opts.should_cast changed the
//     transported type,
//   - an Identity node fed by the recv when 'edge' is a control edge,
//   - otherwise the recv node itself.
// Returns nullptr, with *status set, if finalizing any node fails.
NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info,
                 GraphDef* gdef, const Edge* edge, NodeDef** real_recv,
                 Status* status) {
  const DataType dtype = EdgeType(edge);
  const Node* src = edge->src();
  const Node* dst = edge->dst();
  const int dst_port = edge->dst_input();
  DataType cast_dtype = dtype;

  // NOTE(yuanbyu): Only cast for cross-device send/recv.
  if (opts.should_cast && !NeedSameDeviceSendRecv(edge, g_info)) {
    cast_dtype = opts.should_cast(edge);
  }

  // host_memory = true iff we need to use HostRecv/HostCast.
  bool host_memory = false;
  if (!edge->IsControlEdge()) {
    auto dst_it = g_info.input_types.find({dst->id(), dst_port});
    DCHECK(dst_it != g_info.input_types.end());
    host_memory = (dst_it->second == HOST_MEMORY);
  }

  // Add the recv node.
  const string recv_op = (host_memory) ? "_HostRecv" : "_Recv";
  NodeDefBuilder recv_builder(opts.new_name(src->name()), recv_op,
                              NodeDebugInfo(*src));
  SetSendRecvAttrs(opts, edge, &recv_builder);
  recv_builder.Device(dst->assigned_device_name())
      .Attr("tensor_type", cast_dtype);
  NodeDef* recv = gdef->add_node();
  *status = recv_builder.Finalize(recv);
  if (!status->ok()) return nullptr;
  *real_recv = recv;

  // Add the cast node (from cast_dtype to dtype) or an Identity node.
  if (dtype != cast_dtype) {
    const string cast_op = (host_memory) ? "_HostCast" : "Cast";
    NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
                                NodeDebugInfo(*src));
    cast_builder.Attr("DstT", dtype);
    cast_builder.Device(dst->assigned_device_name())
        .Input(recv->name(), 0, cast_dtype);
    NodeDef* cast = gdef->add_node();
    *status = cast_builder.Finalize(cast);
    if (!status->ok()) return nullptr;
    return cast;
  } else if (edge->IsControlEdge()) {
    // An Identity is only needed for control edges.
    NodeDefBuilder id_builder(opts.new_name(src->name()), "Identity",
                              NodeDebugInfo(*src));
    id_builder.Device(dst->assigned_device_name())
        .Input(recv->name(), 0, cast_dtype);
    NodeDef* id = gdef->add_node();
    *status = id_builder.Finalize(id);
    if (!status->ok()) return nullptr;
    return id;
  } else {
    return recv;
  }
}

// Adds to 'gdef' a float "Const" node holding a tensor of shape {0}, placed
// on the assigned device of the edge's source node. The NodeDef is returned
// regardless of the finalization result stored in *status.
NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef,
                       const Edge* edge, Status* status) {
  const Node* src = edge->src();
  Tensor tensor(DT_FLOAT, TensorShape({0}));
  NodeDef* result = gdef->add_node();
  *status = NodeDefBuilder(opts.new_name(src->name()), "Const")
                .Device(src->assigned_device_name())
                .Attr("dtype", DT_FLOAT)
                .Attr("value", tensor)
                .Finalize(result);
  return result;
}

// A dummy node for scheduling.
// Adds a "ControlTrigger" node, named from "synch_<epoch>" via opts.new_name,
// with its "_start_time" attribute set to 'starttime'. The NodeDef is
// returned regardless of the finalization result stored in *status.
NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef,
                           const string& assigned_device_name, int64 epoch,
                           int64 starttime, Status* status) {
  NodeDef* result = gdef->add_node();
  *status = NodeDefBuilder(opts.new_name(strings::StrCat("synch_", epoch)),
                           "ControlTrigger")
                .Device(assigned_device_name)
                .Attr("_start_time", starttime)
                .Finalize(result);
  return result;
}

// Optimize colocation for control flow nodes. For cond, we want the
// switch nodes to colocate with its data input. This is particularly
// needed for conditional reading of a remote variable. It may also
// reduce the number of devices involved in a loop.
// TODO(yuanbyu): In this case, we don't respect the requested device in
// the GraphDef for these nodes. Ideally, the placer would enforce the
// colocation to render this unnecessary.
void OptimizeControlFlowColocation(Graph* graph) {
  // Reassigns the device of a single control flow node; other node kinds are
  // left untouched.
  auto visit = [](Node* node) {
    if (IsSwitch(node)) {
      for (const Edge* in_edge : node->in_edges()) {
        if (in_edge->dst_input() == 0) {
          // Colocate with the data input.
          node->set_assigned_device_name(
              in_edge->src()->assigned_device_name());
          return;
        }
      }
    } else if (IsExit(node)) {
      for (const Edge* in_edge : node->in_edges()) {
        if (!in_edge->IsControlEdge()) {
          // Colocate with upstream node.
          node->set_assigned_device_name(
              in_edge->src()->assigned_device_name());
          return;
        }
      }
    } else {
      // Enter nodes whose input 0 is a ref type are deliberately skipped.
      if ((IsEnter(node) && !IsRefType(node->input_type(0))) ||
          IsNextIteration(node)) {
        const Edge* data_edge = nullptr;
        for (const Edge* out_edge : node->out_edges()) {
          if (!out_edge->IsControlEdge()) {
            data_edge = out_edge;
            break;
          }
        }
        // Colocate with the first downstream data node.
        if (data_edge) {
          node->set_assigned_device_name(
              data_edge->dst()->assigned_device_name());
        }
      }
    }
  };
  DFS(*graph, visit, {});
}

// Returns 'name' prefixed with "_cloop", the marker that IsControlLoop()
// looks for.
string ControlLoopName(const string& name) {
  return strings::StrCat("_cloop", name);
}

// True when the node's name starts with the "_cloop" prefix.
bool IsControlLoop(const Node* node) {
  const string& name = node->name();
  return str_util::StartsWith(name, "_cloop");
}

// An enter node for control flow.
// The node is built with a placeholder input named "dummy" of type DT_FLOAT
// and the given frame_name / parallel_iterations attributes, then assigned
// to 'device_name'. Returns nullptr, with *status set, if finalization fails.
Node* AddControlEnter(Graph* g, const string& node_name,
                      const string& device_name, const string& frame_name,
                      const int parallel_iterations, Status* status) {
  NodeBuilder node_builder(node_name, "Enter", g->op_registry());
  node_builder.Input({"dummy", 0, DT_FLOAT});
  node_builder.Attr("frame_name", frame_name);
  node_builder.Attr("parallel_iterations", parallel_iterations);
  Node* res_node;
  *status = node_builder.Finalize(g, &res_node);
  if (!status->ok()) return nullptr;
  res_node->set_assigned_device_name(device_name);
  return res_node;
}

// A merge node for control flow.
410 Node* AddControlMerge(const string& in_name1, const string& in_name2, Graph* g, 411 const string& node_name, const string& device_name, 412 Status* status) { 413 NodeBuilder node_builder(node_name, "Merge", g->op_registry()); 414 node_builder.Input({{in_name1, 0, DT_FLOAT}, {in_name2, 0, DT_FLOAT}}); 415 Node* res_node; 416 *status = node_builder.Finalize(g, &res_node); 417 if (!status->ok()) return nullptr; 418 res_node->set_assigned_device_name(device_name); 419 return res_node; 420 } 421 422 // A switch node for control flow. 423 Node* AddControlSwitch(NodeBuilder::NodeOut input1, NodeBuilder::NodeOut input2, 424 const string& device_name, 425 const GraphDefBuilder::Options& bopts) { 426 Node* res_node = 427 ops::BinaryOp("Switch", std::move(input1), std::move(input2), bopts); 428 if (bopts.HaveError()) return nullptr; 429 res_node->set_assigned_device_name(device_name); 430 return res_node; 431 } 432 433 // A next_iteration node for control flow. 434 Node* AddControlNext(NodeBuilder::NodeOut input, const string& device_name, 435 const GraphDefBuilder::Options& bopts) { 436 Node* res_node = ops::UnaryOp("NextIteration", std::move(input), bopts); 437 if (bopts.HaveError()) return nullptr; 438 res_node->set_assigned_device_name(device_name); 439 return res_node; 440 } 441 442 Node* EmptyConst(const GraphDefBuilder::Options& options) { 443 if (options.HaveError()) return nullptr; 444 NodeBuilder node_builder(options.GetNameForOp("Const"), "Const", 445 options.op_registry()); 446 const DataType dt = DataTypeToEnum<float>::v(); 447 TensorProto proto; 448 proto.set_dtype(dt); 449 TensorShape empty_shape({0}); 450 empty_shape.AsProto(proto.mutable_tensor_shape()); 451 node_builder.Attr("dtype", dt).Attr("value", proto); 452 return options.FinalizeBuilder(&node_builder); 453 } 454 455 // A dummy const node for control flow. 
456 Node* AddControlConst(const string& device_name, 457 const GraphDefBuilder::Options& bopts) { 458 Node* res_node = EmptyConst(bopts); 459 if (bopts.HaveError()) return nullptr; 460 res_node->set_assigned_device_name(device_name); 461 return res_node; 462 } 463 464 // A synthetic loop, made up of dummy nodes. It performs control-flow actions 465 // on behalf of a leader on a different device. 466 struct ControlLoop { 467 Node* enter = nullptr; 468 Node* merge = nullptr; 469 Node* switch_node = nullptr; 470 }; 471 472 // Add the control flow info of a new node added during partitioning. 473 // The new node has the same control flow info as src. 474 void AddControlFlowInfo(const Node* node, const Node* src, 475 std::vector<ControlFlowInfo>* cf_info) { 476 int id = node->id(); 477 if (static_cast<size_t>(id) >= cf_info->size()) { 478 cf_info->resize(id + 1); 479 } 480 const ControlFlowInfo& src_info = (*cf_info)[src->id()]; 481 ControlFlowInfo* info = &(*cf_info)[id]; 482 info->frame = src_info.frame; 483 info->parent_frame = src_info.parent_frame; 484 info->frame_name = src_info.frame_name; 485 } 486 487 // Constructs a control loop. Returns a struct containing the newly created 488 // enter, merge, and switch nodes. The enter and merge nodes are used in the 489 // recursive construction of control loops for nested frames (loops). The 490 // switch node will be connected to the LoopCond node. The merge node will 491 // be connected to all the recvs of the same frame by control edges when 492 // the actual partitioning happens. 
493 Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src, 494 const Edge* edge, Node* loop_cond, 495 std::vector<ControlFlowInfo>* cf_info, 496 ControlLoop* loop) { 497 Status status; 498 GraphDefBuilder::Options bopts(g, &status); 499 const ControlFlowInfo& src_info = (*cf_info)[src->id()]; 500 const string& device_name = edge->dst()->assigned_device_name(); 501 const string& frame_name = src_info.frame_name; 502 int parallel_iterations; 503 status = GetNodeAttr(src_info.frame->attrs(), "parallel_iterations", 504 ¶llel_iterations); 505 if (!status.ok()) return status; 506 507 // The names of the nodes to be added. 508 const string& enter_name = 509 ControlLoopName(opts.new_name(edge->dst()->name())); 510 const string& merge_name = 511 ControlLoopName(opts.new_name(edge->dst()->name())); 512 const string& switch_name = 513 ControlLoopName(opts.new_name(edge->dst()->name())); 514 const string& next_name = ControlLoopName(opts.new_name(edge->dst()->name())); 515 516 // Add the nodes to the graph g. 
517 Node* enter = AddControlEnter(g, enter_name, device_name, frame_name, 518 parallel_iterations, &status); 519 if (!status.ok()) return status; 520 Node* merge = AddControlMerge(enter_name, next_name, g, merge_name, 521 device_name, &status); 522 if (!status.ok()) return status; 523 Node* switch_node = AddControlSwitch(merge, loop_cond, device_name, 524 bopts.WithName(switch_name)); 525 if (!status.ok()) return status; 526 Node* next = 527 AddControlNext({switch_node, 1}, device_name, bopts.WithName(next_name)); 528 if (!status.ok()) return status; 529 530 // Add control flow info for these new nodes: 531 AddControlFlowInfo(enter, src, cf_info); 532 AddControlFlowInfo(merge, src, cf_info); 533 AddControlFlowInfo(switch_node, src, cf_info); 534 AddControlFlowInfo(next, src, cf_info); 535 536 // Add input edges for the newly created merge node: 537 g->AddEdge(enter, 0, merge, 0); 538 g->AddEdge(next, 0, merge, 1); 539 540 loop->enter = enter; 541 loop->merge = merge; 542 loop->switch_node = switch_node; 543 return Status::OK(); 544 } 545 546 // Build memory and device type info for every node in the graph. 547 // TODO(yuanbyu): It might be simpler if we convert MemoryType to 548 // DeviceType for the inputs/outputs of each node. 
549 Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) { 550 MemoryTypeVector input_memory_types; 551 MemoryTypeVector output_memory_types; 552 553 info->device_types.resize(g.num_node_ids(), DEVICE_CPU); 554 for (const Node* node : g.op_nodes()) { 555 DeviceNameUtils::ParsedName parsed; 556 if (!DeviceNameUtils::ParseFullName(node->assigned_device_name(), 557 &parsed)) { 558 return errors::Internal("Malformed assigned device '", 559 node->assigned_device_name(), "'"); 560 } 561 562 TF_RETURN_IF_ERROR(MemoryTypesForNode( 563 g.op_registry(), DeviceType(parsed.type), node->def(), 564 &input_memory_types, &output_memory_types)); 565 566 int node_id = node->id(); 567 info->device_types[node_id] = DeviceType(parsed.type); 568 for (int i = 0; i < input_memory_types.size(); ++i) { 569 info->input_types[{node_id, i}] = input_memory_types[i]; 570 } 571 for (int i = 0; i < output_memory_types.size(); ++i) { 572 info->output_types[{node_id, i}] = output_memory_types[i]; 573 } 574 } 575 return Status::OK(); 576 } 577 578 const Node* InputFrame(const Node* node, 579 const std::vector<ControlFlowInfo>& cf_info) { 580 // An input is in the same frame as the node except for Enter nodes. 581 // The input of Enter is in the parent frame of the Enter node. 582 if (!node->IsEnter()) { 583 return node; 584 } 585 return cf_info[node->id()].parent_frame; 586 } 587 588 const Node* OutputFrame(const Node* node, 589 const std::vector<ControlFlowInfo>& cf_info) { 590 // An output is in the same frame as the node except for Exit nodes. 591 // The output of Exit is in the parent frame of the Exit node. 592 if (!node->IsExit()) { 593 return node; 594 } 595 return cf_info[node->id()].parent_frame; 596 } 597 598 // Each participating device needs to decide a) if there is a next iteration, 599 // and b) if the loop terminates. We take the approach to encode this control 600 // flow logic in the dataflow graph. There are at least two possible encodings. 
// In a completely decentralized encoding, the participants communicate peer
// to peer. The other encoding uses a frame leader (the participant who owns
// the pivot termination predicate) to broadcast the termination condition to
// all the participants. For now we take the latter because it is simpler.
//
// TODO(yuanbyu): The correctness of this construction is rather subtle. I got
// it wrong many times so it would be nice to write a proof to be sure.
Status AddControlFlow(const PartitionOptions& opts, Graph* g,
                      GraphInfo* g_info) {
  // 'bopts' reports builder errors through 'status'.
  Status status;
  GraphDefBuilder::Options bopts(g, &status);
  std::vector<ControlFlowInfo>& cf_info = g_info->cf_info;

  // Build the control flow info for every node.
  status = BuildControlFlowInfo(g, &cf_info);
  if (!status.ok()) return status;

  OptimizeControlFlowColocation(g);

  // The map from frames to their LoopCond nodes.
  std::unordered_map<string, Node*> frame_cond_map;
  int num_node_ids = g->num_node_ids();
  for (int i = 0; i < num_node_ids; ++i) {
    Node* node = g->FindNodeId(i);
    if (node == nullptr) continue;

    if (IsLoopCond(node)) {
      const string& frame_name = cf_info[node->id()].frame_name;
      DCHECK(!frame_name.empty());
      frame_cond_map[frame_name] = node;
    }
  }

  // Add all control loops for cross-device frames.
  // A control loop is added only when there is a cross-device edge in a
  // non-root frame. Nothing is added if there are no loops. We also don't
  // add anything for a frame that is completely local to a device. For
  // nested loops, we stack the control loops together by connecting
  // the merge of the outer loop to the enter of the inner loop.
  //
  // A map from <frame_name, device_name> to ControlLoop.
  std::unordered_map<string, ControlLoop> control_loops;
  // Snapshot the edge count: the loop body adds edges, and those new edges
  // are not visited by this pass.
  int num_edge_ids = g->num_edge_ids();
  for (int i = 0; i < num_edge_ids; ++i) {
    const Edge* edge = g->FindEdgeId(i);
    if (edge == nullptr) continue;

    const Node* src = edge->src();
    const Node* dst = edge->dst();
    // Skip Sink/Source nodes.
    if (!src->IsOp() || !dst->IsOp()) continue;

    const string& src_device = src->assigned_device_name();
    const string& dst_device = dst->assigned_device_name();
    // Skip local edges.
    if (src_device == dst_device) continue;

    const Node* src_frame = OutputFrame(src, cf_info);
    const Node* dst_frame = InputFrame(dst, cf_info);
    const string& src_frame_name = cf_info[src_frame->id()].frame_name;
    const string& dst_frame_name = cf_info[dst_frame->id()].frame_name;
    // Skip if src and dst are not in the same frame.
    if (src_frame_name.empty() || src_frame_name != dst_frame_name) {
      continue;
    }

    // Add the control loop. Start by adding the control loop for the
    // current frame if needed, and recursively adding the control loop
    // for its outer frame when nested.
    ControlLoop child_loop;
    while (true) {
      const string& curr_frame_name = cf_info[src_frame->id()].frame_name;
      if (curr_frame_name.empty()) {
        // We have reached the root frame.
        if (child_loop.merge != nullptr) {
          // Feed the outermost new control loop from a dummy constant placed
          // on the destination device.
          const string& node_name = opts.new_name(edge->dst()->name());
          const string& device_name = edge->dst()->assigned_device_name();
          Node* const_node =
              AddControlConst(device_name, bopts.WithName(node_name));
          if (!status.ok()) return status;
          AddControlFlowInfo(const_node, src_frame, &cf_info);
          g->AddEdge(const_node, 0, child_loop.enter, 0);
        }
        break;
      }

      // Control loops are keyed by "<frame name>$$<destination device>".
      const string& cl_key = strings::StrCat(curr_frame_name, "$$", dst_device);
      auto it = control_loops.find(cl_key);
      if (it != control_loops.end()) {
        // A loop for this frame/device already exists; hook the inner loop
        // (if any) to it and stop walking outwards.
        if (child_loop.enter != nullptr) {
          g->AddEdge(it->second.merge, 0, child_loop.enter, 0);
        }
        break;
      }

      // Get the frame's LoopCond.
      auto cond_it = frame_cond_map.find(curr_frame_name);
      if (cond_it == frame_cond_map.end()) {
        return errors::InvalidArgument(
            "A cross-device loop must have a pivot predicate: ",
            curr_frame_name);
      }
      Node* loop_cond = cond_it->second;

      // Add the control loop.
      // NOTE: this may resize cf_info, so 'curr_frame_name' must not be used
      // past this call.
      ControlLoop curr_loop;
      status = AddControlLoop(opts, g, src_frame, edge, loop_cond, &cf_info,
                              &curr_loop);
      if (!status.ok()) return status;
      control_loops[cl_key] = curr_loop;

      if (child_loop.enter != nullptr) {
        // Connect the merge of the outer loop to the enter of the inner.
        g->AddEdge(curr_loop.merge, 0, child_loop.enter, 0);
      }
      // Move one frame outwards.
      src_frame = cf_info[src_frame->id()].parent_frame;
      child_loop = curr_loop;
    }
  }

  // For a cross-device edge, on the dst device, add a control edge
  // from the merge node of the control loop to dst. If a send/recv is
  // introduced for this edge in future partitioning, we delete this
  // control edge and add a new control edge from the merge to the recv.
  num_edge_ids = g->num_edge_ids();
  for (int i = 0; i < num_edge_ids; ++i) {
    const Edge* edge = g->FindEdgeId(i);
    if (edge == nullptr) continue;

    const Node* src = edge->src();
    Node* dst = edge->dst();
    // Skip Sink/Source nodes.
    if (!src->IsOp() || !dst->IsOp()) continue;

    const string& src_device = src->assigned_device_name();
    const string& dst_device = dst->assigned_device_name();
    if (src_device != dst_device) {
      const Node* src_frame = OutputFrame(src, cf_info);
      const Node* dst_frame = InputFrame(dst, cf_info);
      const string& src_frame_name = cf_info[src_frame->id()].frame_name;
      const string& dst_frame_name = cf_info[dst_frame->id()].frame_name;
      if (!src_frame_name.empty() && src_frame_name == dst_frame_name) {
        const string& cl_key =
            strings::StrCat(dst_frame_name, "$$", dst_device);
        // operator[] default-inserts a ControlLoop with null members if the
        // key is missing; the DCHECK below only catches that in debug builds.
        ControlLoop loop = control_loops[cl_key];
        DCHECK(loop.enter != nullptr);
        // Note that we'll create multiple duplicate edges if dst has multiple
        // cross-device inputs. This is expected by the logic in Partition(), so
        // it can add control edges to the recv nodes once they're created.
        g->AddControlEdge(loop.merge, dst, /*allow_duplicates=*/true);
      }
    }
  }
  return Status::OK();
}

// Queue entry pairing a NodeDef with the start time it was enqueued with.
struct PriorityTopoSortNode {
  PriorityTopoSortNode(const NodeDef* n, int64 st) : node(n), start_time(st) {}

  const NodeDef* node;
  int64 start_time;
};

// Orders entries by descending start_time; used as the std::priority_queue
// comparator so that the entry with the lowest start_time is on top.
struct PriorityTopoSortNodeGreater {
  bool operator()(const PriorityTopoSortNode& left,
                  const PriorityTopoSortNode& right) {
    return left.start_time > right.start_time;
  }
};

}  // namespace

// Returns in <nodes> the nodes that should participate in epoch-based recv
// scheduling, along with their times; <nodes> is ordered by increasing
// start_time. Returns in <node_to_start_time_out> the timing for all nodes,
// even those not in <nodes>.
777 // 778 // Comparing to sorting on the node's start time only, this also processes the 779 // nodes in dependency order, and updates start times to ensure a node's 780 // start_time > the start time for all dependencies. 781 // 782 // Note that graph_partition_test.cc accesses this function for testing, even 783 // though it's not declared in the header. 784 Status TopologicalSortNodesWithTimePriority( 785 const GraphDef* gdef, std::vector<std::pair<const NodeDef*, int64>>* nodes, 786 std::unordered_map<const NodeDef*, int64>* node_to_start_time_out) { 787 // Queue of nodes to process; lowest start time is returned first. 788 std::priority_queue<PriorityTopoSortNode, std::vector<PriorityTopoSortNode>, 789 PriorityTopoSortNodeGreater> 790 q; 791 std::unordered_map<const NodeDef*, int64> node_to_start_time; 792 auto enqueue = [&q, &node_to_start_time](const NodeDef* node) { 793 const int64 start_time = node_to_start_time[node]; 794 q.emplace(node, start_time); 795 }; 796 797 // Build initial structures, initial contents of queue. 798 std::unordered_map<string, std::vector<const NodeDef*>> node_to_output_nodes; 799 std::unordered_map<const NodeDef*, int> inputs_needed; 800 for (int n = 0; n < gdef->node_size(); ++n) { 801 const NodeDef* ndef = &gdef->node(n); 802 for (int i = 0; i < ndef->input_size(); ++i) { 803 node_to_output_nodes[string(ParseTensorName(ndef->input(i)).first)] 804 .push_back(ndef); 805 } 806 int64 start_time; 807 TF_RETURN_IF_ERROR(GetNodeAttr(*ndef, "_start_time", &start_time)); 808 node_to_start_time[ndef] = start_time; 809 inputs_needed[ndef] = ndef->input_size(); 810 if (ndef->input_size() == 0) { 811 enqueue(ndef); 812 } 813 } 814 815 // Determine which merge nodes are parts of loops; these 816 // need to happen in the traversal after all non-NextIteration inputs 817 // are run. 
818 for (int n = 0; n < gdef->node_size(); ++n) { 819 const NodeDef* ndef = &gdef->node(n); 820 if (IsNextIteration(*ndef)) { 821 for (const NodeDef* n : node_to_output_nodes[ndef->name()]) { 822 if (IsMerge(*n)) { 823 // n is a merge that is part of a loop structure. 824 // It doesn't need to wait for this NextIteration loop 825 // when doing the traversal. 826 --inputs_needed[n]; 827 } 828 } 829 } 830 } 831 832 // Traverse. 833 std::vector<std::pair<const NodeDef*, int64>> start_times; 834 start_times.reserve(gdef->node_size()); 835 while (!q.empty()) { 836 PriorityTopoSortNode cur = q.top(); 837 q.pop(); 838 839 start_times.emplace_back(cur.node, cur.start_time); 840 841 for (const NodeDef* n : node_to_output_nodes[cur.node->name()]) { 842 auto& output_start_time = node_to_start_time[n]; 843 if (output_start_time <= cur.start_time) { 844 output_start_time = cur.start_time + 1; 845 } 846 if (--inputs_needed[n] == 0) { 847 enqueue(n); 848 } 849 } 850 } 851 852 // Done. 853 nodes->swap(start_times); 854 node_to_start_time_out->swap(node_to_start_time); 855 return Status::OK(); 856 } 857 858 Status AddControlEdges(const PartitionOptions& opts, 859 std::unordered_map<string, GraphDef>* partitions) { 860 Status status; 861 // TODO(yuanbyu): Very naive for now. To be improved. 862 const int num_epochs = 100; 863 const int prefetch = 6; 864 865 for (auto& part : *partitions) { 866 GraphDef* gdef = &part.second; 867 std::vector<std::pair<const NodeDef*, int64>> start_times; 868 std::unordered_map<const NodeDef*, int64> node_to_start_time; 869 status = TopologicalSortNodesWithTimePriority(gdef, &start_times, 870 &node_to_start_time); 871 if (!status.ok()) { 872 return status; 873 } 874 875 // Add a dummy node for every epoch, and add a control edge from the 876 // "last" node in the preceding epoch to the dummy node. 
877 string device_name = gdef->node(0).device(); 878 int64 makespan = start_times.back().second; 879 int64 resolution = (makespan / num_epochs) + 1; 880 881 int i = 0; 882 int j = 0; 883 std::vector<NodeDef*> dummys; 884 while (i < num_epochs && static_cast<size_t>(j) < start_times.size()) { 885 if (i * resolution > start_times[j].second) { 886 j++; 887 } else { 888 NodeDef* dummy = AddControlTrigger(opts, gdef, device_name, i, 889 i * resolution, &status); 890 if (!status.ok()) { 891 return status; 892 } 893 dummys.push_back(dummy); 894 if (j > 0) { 895 string src_name = start_times[j - 1].first->name(); 896 AddInput(dummy, src_name, Graph::kControlSlot); 897 } 898 i++; 899 } 900 } 901 902 // Finally, add the control edges to recvs. 903 for (int n = 0; n < gdef->node_size(); ++n) { 904 NodeDef* ndef = gdef->mutable_node(n); 905 if (ndef->op() == "_Recv") { 906 const int64 start_time = node_to_start_time[ndef]; 907 const int recv_epoch = start_time / resolution; 908 if (recv_epoch >= prefetch) { 909 NodeDef* dummy = dummys[recv_epoch - prefetch]; 910 AddInput(ndef, dummy->name(), Graph::kControlSlot); 911 } 912 } 913 } 914 } 915 return Status::OK(); 916 } 917 918 // If 'ndef' is a Send or Recv, fills its attr send_device_incarnation 919 // if possible. 920 void SetIncarnation(const PartitionOptions& opts, NodeDef* ndef) { 921 StringPiece op(ndef->op()); 922 if (op != "_Send" && op != "_Recv") { 923 // Not related to send/recv. 924 return; 925 } 926 string send_device; 927 if (!GetNodeAttr(*ndef, "send_device", &send_device).ok()) { 928 // No known send_device. The runtime will detect it later. 
929 return; 930 } 931 int64 incarnation = PartitionOptions::kIllegalIncarnation; 932 if (!GetNodeAttr(*ndef, "send_device_incarnation", &incarnation).ok() || 933 (incarnation == PartitionOptions::kIllegalIncarnation)) { 934 incarnation = opts.get_incarnation(send_device); 935 SetAttrValue(incarnation, 936 &((*ndef->mutable_attr())["send_device_incarnation"])); 937 } 938 } 939 940 // Sets attribute send_device_incarnation of all Send/Recv nodes in 941 // 'gdef', if possible. 942 void SetIncarnation(const PartitionOptions& opts, GraphDef* gdef) { 943 for (NodeDef& ndef : *gdef->mutable_node()) { 944 SetIncarnation(opts, &ndef); 945 } 946 for (FunctionDef& fdef : *gdef->mutable_library()->mutable_function()) { 947 for (NodeDef& ndef : *fdef.mutable_node_def()) { 948 SetIncarnation(opts, &ndef); 949 } 950 } 951 } 952 953 Status Partition(const PartitionOptions& opts, Graph* g, 954 std::unordered_map<string, GraphDef>* partitions) { 955 Status status; 956 partitions->clear(); 957 958 GraphInfo g_info; 959 if (!opts.control_flow_added) { 960 // Add the "code" for distributed execution of control flow. Code is 961 // added only for the frames that are placed on multiple devices. The 962 // new graph is an equivalent transformation of the original graph and 963 // has the property that it can be subsequently partitioned arbitrarily 964 // (down to the level of individual device) for distributed execution. 965 status = AddControlFlow(opts, g, &g_info); 966 if (!status.ok()) return status; 967 } 968 969 // At this point, all the graph mutations have been done. Build memory 970 // and device type info for every node and edge in the graph. 971 status = BuildMemoryDeviceInfo(*g, &g_info); 972 if (!status.ok()) return status; 973 974 string dstp; 975 std::vector<const Edge*> inputs; 976 DupRecvTable dup_recv(3); 977 // For a node dst, 'ref_recvs' remembers the recvs introduced by a ref 978 // edge to dst. 'ref_control_inputs' remembers the inputs by a non-ref 979 // edge to dst. 
We will add a control edge for every pair in 980 // (ref_recvs x ref_control_inputs). 981 std::vector<NodeDef*> ref_recvs; 982 std::vector<string> ref_control_inputs; 983 984 int32 num_data = 0; 985 int32 num_control = 0; 986 for (const Node* dst : g->op_nodes()) { 987 dstp = opts.node_to_loc(dst); 988 GraphDef* dst_graph = &(*partitions)[dstp]; 989 NodeDef* dst_def = dst_graph->add_node(); 990 *dst_def = dst->def(); 991 MergeDebugInfo(NodeDebugInfo(dst->def()), dst_def); 992 dst_def->set_device(dst->assigned_device_name()); 993 dst_def->clear_input(); // Inputs are filled below 994 if (opts.need_to_record_start_times) { 995 int64 start_time; 996 status = GetNodeAttr(*dst_def, "_start_time", &start_time); 997 if (errors::IsNotFound(status)) { 998 start_time = opts.start_times[dst->id()].value(); 999 AddNodeAttr("_start_time", start_time, dst_def); 1000 } else if (!status.ok()) { 1001 return status; 1002 } 1003 } 1004 1005 // Arrange the incoming edges to dst so that input[i] holds the 1006 // input flowing into slot numbered i. Trailing entries in input[] 1007 // hold control edges. 1008 inputs.clear(); 1009 inputs.resize(dst->num_inputs(), nullptr); 1010 ref_recvs.clear(); 1011 ref_control_inputs.clear(); 1012 const Edge* control_flow_edge = nullptr; 1013 int32 num_control_flow_edges = 0; 1014 int32 num_input_edges = 0; 1015 for (const Edge* edge : dst->in_edges()) { 1016 if (edge->IsControlEdge()) { 1017 if (IsMerge(edge->src()) && IsControlLoop(edge->src())) { 1018 // This is one of the control edges added for control flow. There 1019 // can be multiple such edges as the dest node may have multiple 1020 // remote inputs. We keep track of the number of such edges. 
1021 control_flow_edge = edge; 1022 ++num_control_flow_edges; 1023 } else { 1024 inputs.push_back(edge); 1025 } 1026 } else { 1027 DCHECK(inputs[edge->dst_input()] == nullptr); 1028 inputs[edge->dst_input()] = edge; 1029 ++num_input_edges; 1030 } 1031 } 1032 1033 if (num_input_edges != dst->num_inputs()) { 1034 return errors::InvalidArgument("Incomplete graph, missing ", 1035 (dst->num_inputs() - num_input_edges), 1036 " inputs for ", dst->name()); 1037 } 1038 1039 // Process in order so that all data edges are added as inputs to 1040 // dst in Edge::dst_input() order. 1041 for (const Edge* edge : inputs) { 1042 const Node* src = edge->src(); 1043 if (!src->IsOp()) continue; // Skip Sink/Source nodes. 1044 1045 GraphDef* src_graph = &(*partitions)[opts.node_to_loc(src)]; 1046 if (src_graph == dst_graph && !NeedSameDeviceSendRecv(edge, g_info)) { 1047 // Same partition and compatible memory types: 1048 AddInput(dst_def, src->name(), edge->src_output()); 1049 if (edge->IsControlEdge() || 1050 !IsRefType(src->output_type(edge->src_output()))) { 1051 ref_control_inputs.push_back(src->name()); 1052 } 1053 continue; 1054 } 1055 1056 int64 send_start_time = 0; 1057 int64 recv_start_time = 0; 1058 if (opts.scheduling_for_recvs) { 1059 status = GetNodeAttr(src->attrs(), "_start_time", &send_start_time); 1060 if (errors::IsNotFound(status) && opts.need_to_record_start_times) { 1061 send_start_time = opts.start_times[src->id()].value(); 1062 } else if (!status.ok()) { 1063 return status; 1064 } 1065 1066 status = GetNodeAttr(dst->attrs(), "_start_time", &recv_start_time); 1067 if (errors::IsNotFound(status) && opts.need_to_record_start_times) { 1068 recv_start_time = opts.start_times[dst->id()].value(); 1069 } else if (!status.ok()) { 1070 return status; 1071 } 1072 } 1073 1074 // Check whether there is already a send/recv pair transferring 1075 // the same tensor/control from the src to dst partition. 
1076 const bool on_host = IsDstInputOnHost(edge, g_info); 1077 DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host}; 1078 auto iter = dup_recv.find(key); 1079 if (iter != dup_recv.end()) { 1080 // We found one. Reuse the data/control transferred already. 1081 const string& recv_node_name = iter->second.recv->name(); 1082 if (edge->IsControlEdge()) { 1083 AddInput(dst_def, recv_node_name, Graph::kControlSlot); 1084 } else { 1085 AddInput(dst_def, recv_node_name, 0); 1086 } 1087 ref_control_inputs.push_back(recv_node_name); 1088 1089 // We want the start_time for the recv to be the smallest of the start 1090 // times of it's consumers. So we update this whenever we use a recv, 1091 // and write it out to the attribute at the end of the subroutine 1092 if (iter->second.start_time > recv_start_time) { 1093 iter->second.start_time = recv_start_time; 1094 } 1095 continue; 1096 } 1097 1098 NodeDefBuilder::NodeOut send_from; 1099 if (edge->IsControlEdge()) { 1100 // Insert a dummy const node that will generate a tiny 1101 // data element to be sent from send to recv. 1102 VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "[" 1103 << src->name() << "] -> " << dst->assigned_device_name() << "[" 1104 << dst->name() << "]"; 1105 NodeDef* dummy = AddDummyConst(opts, src_graph, edge, &status); 1106 if (!status.ok()) return status; 1107 // Set the start time for this dummy node. 1108 if (opts.scheduling_for_recvs) { 1109 AddNodeAttr("_start_time", send_start_time, dummy); 1110 } 1111 AddInput(dummy, src->name(), Graph::kControlSlot); 1112 send_from.Reset(dummy->name(), 0, DT_FLOAT); 1113 } else { 1114 send_from.Reset(src->name(), edge->src_output(), EdgeType(edge)); 1115 } 1116 1117 // Need to split edge by placing matching send/recv nodes on 1118 // the src/dst sides of the edge. 
1119 NodeDef* send = AddSend(opts, g_info, src_graph, edge, send_from, 1120 send_start_time, &status); 1121 if (!status.ok()) return status; 1122 1123 NodeDef* real_recv = nullptr; 1124 NodeDef* recv = 1125 AddRecv(opts, g_info, dst_graph, edge, &real_recv, &status); 1126 if (!status.ok()) return status; 1127 1128 // Fix up the control flow edge. 1129 // NOTE(yuanbyu): 'real_recv' must be the real recv node. 1130 if (src_graph == dst_graph) { 1131 // For same device send/recv, add a control edge from send to recv. 1132 // This prevents the asynchronous recv kernel from being scheduled 1133 // before the data is available. 1134 AddInput(real_recv, send->name(), Graph::kControlSlot); 1135 } else if (control_flow_edge != nullptr) { 1136 // Redirect control edge to the real recv since this is not the same 1137 // device send/recv. 1138 --num_control_flow_edges; 1139 AddInput(real_recv, control_flow_edge->src()->name(), 1140 Graph::kControlSlot); 1141 } 1142 1143 if (!edge->IsControlEdge() && 1144 IsRefType(src->output_type(edge->src_output()))) { 1145 AddNodeAttr("_start_time", recv_start_time, recv); 1146 if (real_recv != recv) { 1147 AddNodeAttr("_start_time", recv_start_time, real_recv); 1148 } 1149 // If src is of ref type and the edge is not a control edge, dst has 1150 // read semantics and therefore we must control the recv. 1151 ref_recvs.push_back(real_recv); 1152 } else { 1153 // Memorize the send/recv pair, only if this is not a "ref" edge. 1154 // NOTE(yuanbyu): Collapsing ref edges requires extreme care so 1155 // for now we don't do it. 1156 dup_recv[key] = {recv, real_recv, recv_start_time}; 1157 ref_control_inputs.push_back(recv->name()); 1158 } 1159 1160 if (edge->IsControlEdge()) { 1161 ++num_control; 1162 AddInput(dst_def, recv->name(), Graph::kControlSlot); 1163 } else { 1164 ++num_data; 1165 AddInput(dst_def, recv->name(), 0); 1166 } 1167 } 1168 1169 // Add control edges from 'ref_control_inputs' to 'ref_recvs'. 
1170 // NOTE(yuanbyu): Adding these control edges should not introduce 1171 // deadlocks. 'dst' has implicit "read" nodes that, when we split 1172 // across devices, are made explicit; Retargeting the dependencies 1173 // to 'dst' to those nodes would not introduce cycles if there isn't 1174 // one before the transformation. 1175 // NOTE(yuanbyu): This may impact performance because it defers the 1176 // execution of recvs until all the other inputs become available. 1177 AddReadControl(ref_recvs, ref_control_inputs); 1178 1179 // Add back the control edges for control flow that are not used. 1180 if (control_flow_edge != nullptr) { 1181 for (int i = 0; i < num_control_flow_edges; ++i) { 1182 AddInput(dst_def, control_flow_edge->src()->name(), 1183 Graph::kControlSlot); 1184 } 1185 } 1186 } 1187 1188 const FunctionLibraryDefinition* flib_def = opts.flib_def; 1189 if (flib_def == nullptr) { 1190 flib_def = &g->flib_def(); 1191 } 1192 1193 // Set versions, function library and send/recv incarnation. 1194 for (auto& it : *partitions) { 1195 GraphDef* gdef = &it.second; 1196 *gdef->mutable_versions() = g->versions(); 1197 // Prune unreachable functions from `flib_def` before adding them to `gdef`. 1198 *gdef->mutable_library() = flib_def->ReachableDefinitions(*gdef).ToProto(); 1199 1200 // Traverse the graph to fill every send/recv op's incarnation 1201 // information. 1202 SetIncarnation(opts, gdef); 1203 } 1204 1205 // Set the start times for recvs at the very end. 1206 if (opts.scheduling_for_recvs) { 1207 for (auto& it : dup_recv) { 1208 AddNodeAttr("_start_time", it.second.start_time, it.second.recv); 1209 if (it.second.real_recv != it.second.recv) { 1210 AddNodeAttr("_start_time", it.second.start_time, it.second.real_recv); 1211 } 1212 } 1213 } 1214 1215 VLOG(1) << "Added send/recv: controls=" << num_control 1216 << ", data=" << num_data; 1217 return Status::OK(); 1218 } 1219 1220 } // namespace tensorflow 1221