1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/jit/extract_outside_compilation_pass.h" 17 18 #include "absl/strings/match.h" 19 #include "absl/strings/str_cat.h" 20 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" 21 #include "tensorflow/compiler/jit/encapsulate_util.h" 22 #include "tensorflow/compiler/tf2xla/side_effect_util.h" 23 #include "tensorflow/compiler/tf2xla/tf2xla_util.h" 24 #include "tensorflow/core/common_runtime/function.h" 25 #include "tensorflow/core/framework/function.h" 26 #include "tensorflow/core/framework/graph_to_functiondef.h" 27 #include "tensorflow/core/framework/node_def_builder.h" 28 #include "tensorflow/core/framework/node_def_util.h" 29 #include "tensorflow/core/framework/tensor_shape.pb.h" 30 #include "tensorflow/core/graph/algorithm.h" 31 #include "tensorflow/core/lib/core/errors.h" 32 #include "tensorflow/core/lib/gtl/cleanup.h" 33 #include "tensorflow/core/util/dump_graph.h" 34 35 namespace tensorflow { 36 37 namespace { 38 39 // Add a key placeholder node to the graph. The key placeholder node will be 40 // used as input for XlaRecvAtHost/XlaSendFromHost nodes. 41 xla::StatusOr<Node*> AddHostComputeKeyPlaceholder( 42 const string& xla_cluster_name, Graph* g) { 43 NodeDef key_def; 44 NodeDefBuilder builder(absl::StrCat(xla_cluster_name, "_key_placeholder"), 45 "Placeholder"); 46 builder.Attr("dtype", DT_STRING); 47 builder.Attr("shape", PartialTensorShape({2})); 48 builder.Attr("_host_compute_call_node", xla_cluster_name); 49 Status s = builder.Finalize(&key_def); 50 if (!s.ok()) return s; 51 52 Node* n = g->AddNode(key_def, &s); 53 if (!s.ok()) return s; 54 return n; 55 } 56 57 // Returns if the node is a XLA computation key placeholder. 58 bool IsKeyPlaceholderNode(const Node& n) { 59 return n.type_string() == "Placeholder" && 60 absl::EndsWith(n.name(), "_key_placeholder"); 61 } 62 63 // Returns nodes with given type. 64 std::vector<Node*> GatherNodesWithType(const Graph& g, const string& type) { 65 std::vector<Node*> result; 66 for (Node* n : g.nodes()) { 67 if (n->type_string() == type) { 68 result.push_back(n); 69 } 70 } 71 return result; 72 } 73 74 // Gets data types from `arg_nodes` and fills them into `recv_at_host_dtypes`. 75 Status GetArgDataTypes(const std::vector<Node*>& arg_nodes, 76 std::vector<DataType>* recv_at_host_dtypes) { 77 recv_at_host_dtypes->resize(arg_nodes.size(), DT_INVALID); 78 for (auto* n : arg_nodes) { 79 int index; 80 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); 81 DataType dtype; 82 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype)); 83 (*recv_at_host_dtypes)[index] = dtype; 84 } 85 for (int i = 0; i < recv_at_host_dtypes->size(); i++) { 86 if ((*recv_at_host_dtypes)[i] == DT_INVALID) { 87 return errors::Internal("Cannot get datatype for input ", i); 88 } 89 } 90 return Status::OK(); 91 } 92 93 // Builds XlaRecvAtHost node. 94 xla::StatusOr<Node*> BuildRecvAtHostNode( 95 Graph* g, const string& oc_cluster_name, 96 const std::vector<DataType>& recv_at_host_dtypes, Node* key_placeholder) { 97 NodeDefBuilder recv_at_host_builder( 98 absl::StrCat("outside_compilation_", oc_cluster_name, "_recv"), 99 "_XlaRecvAtHost"); 100 NodeDef recv_at_host_def; 101 recv_at_host_builder.Attr("Toutputs", recv_at_host_dtypes); 102 // The correct device_ordinal will be inserted during replication in a 103 // subsequent rewrite. 104 AttrValue device_ordinal_value; 105 device_ordinal_value.set_placeholder("device_ordinal"); 106 recv_at_host_builder.Attr("device_ordinal", device_ordinal_value); 107 recv_at_host_builder.Attr( 108 "key", absl::StrCat("host_compute_channel_", oc_cluster_name)); 109 recv_at_host_builder.Attr(kXlaHasHostTransferAttrName, true); 110 recv_at_host_builder.Input(key_placeholder->name(), 0, DT_STRING); 111 TF_RETURN_IF_ERROR(recv_at_host_builder.Finalize(&recv_at_host_def)); 112 Status s; 113 Node* recv_at_host_node = g->AddNode(recv_at_host_def, &s); 114 TF_RETURN_IF_ERROR(s); 115 return recv_at_host_node; 116 } 117 118 // Builds XlaRecvAtHost node, and replaces all _Arg nodes with it. 119 xla::StatusOr<Node*> ReplaceArgNodesWithRecvAtHostNode( 120 Graph* g, const string& oc_cluster_name, 121 std::vector<DataType>* recv_at_host_dtypes, Node* key_placeholder) { 122 // TODO(b/77601805): use out nodes for source node, instead of traversing all 123 // nodes. 124 std::vector<Node*> arg_nodes = GatherNodesWithType(*g, "_Arg"); 125 TF_RETURN_IF_ERROR(GetArgDataTypes(arg_nodes, recv_at_host_dtypes)); 126 TF_ASSIGN_OR_RETURN( 127 Node * recv_at_host_node, 128 BuildRecvAtHostNode(g, oc_cluster_name, *recv_at_host_dtypes, 129 key_placeholder)); 130 for (auto* n : arg_nodes) { 131 int index; 132 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); 133 // Record out edges and remove `n` before adding those edges to RecvAtHost. 134 // This is to avoid multiple producers. 135 std::vector<OutEdgeInfo> out_edge_info; 136 for (auto edge : n->out_edges()) { 137 out_edge_info.push_back( 138 {edge->dst(), edge->src_output(), edge->dst_input()}); 139 } 140 g->RemoveNode(n); 141 for (const OutEdgeInfo& edge : out_edge_info) { 142 if (edge.dst_input == Graph::kControlSlot) { 143 g->AddControlEdge(recv_at_host_node, edge.dst); 144 } else { 145 g->AddEdge(recv_at_host_node, index, edge.dst, edge.dst_input); 146 } 147 } 148 149 // Rewrite dst nodes because their input changed. 150 for (int i = 0; i < out_edge_info.size(); i++) { 151 const OutEdgeInfo edge = out_edge_info[i]; 152 if (edge.dst_input == Graph::kControlSlot) { 153 continue; 154 } 155 156 Node* dst = edge.dst; 157 NodeDef new_def = dst->def(); 158 *new_def.mutable_input(edge.dst_input) = 159 absl::StrCat(recv_at_host_node->name(), ":", index); 160 TF_ASSIGN_OR_RETURN(Node * dst_replace, ReplaceNode(g, dst, new_def)); 161 162 // Other edges might have `dst` as dst node as well. Update those edges 163 // with `dst_replace`. 164 for (int j = i + 1; j < out_edge_info.size(); j++) { 165 if (out_edge_info[j].dst == dst) { 166 out_edge_info[j].dst = dst_replace; 167 } 168 } 169 } 170 } 171 g->AddEdge(key_placeholder, 0, recv_at_host_node, 0); 172 return recv_at_host_node; 173 } 174 175 // Gets data types from `ret_nodes` and fills them into `send_from_host_dtypes`. 176 Status GetRetDataTypes(const std::vector<Node*>& ret_nodes, 177 std::vector<DataType>* send_from_host_dtypes) { 178 send_from_host_dtypes->resize(ret_nodes.size(), DT_INVALID); 179 for (auto* n : ret_nodes) { 180 int index; 181 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); 182 DataType dtype; 183 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype)); 184 (*send_from_host_dtypes)[index] = dtype; 185 } 186 for (int i = 0; i < send_from_host_dtypes->size(); i++) { 187 if ((*send_from_host_dtypes)[i] == DT_INVALID) { 188 return errors::Internal("Cannot get datatype for output ", i); 189 } 190 } 191 return Status::OK(); 192 } 193 194 // Builds XlaSendFromHost node. 195 xla::StatusOr<Node*> BuildSendFromHostNode( 196 Graph* g, const string& oc_cluster_name, 197 const std::vector<Node*>& ret_nodes, 198 const std::vector<DataType>& send_from_host_dtypes, Node* key_placeholder) { 199 NodeDefBuilder send_from_host_builder( 200 absl::StrCat("outside_compilation_", oc_cluster_name, "_send"), 201 "_XlaSendFromHost"); 202 NodeDef send_from_host_def; 203 send_from_host_builder.Attr("Tinputs", send_from_host_dtypes); 204 // The correct device_ordinal will be inserted during replication in a 205 // subsequent rewrite. 206 AttrValue device_ordinal_value; 207 device_ordinal_value.set_placeholder("device_ordinal"); 208 send_from_host_builder.Attr("device_ordinal", device_ordinal_value); 209 send_from_host_builder.Attr( 210 "key", absl::StrCat("host_compute_channel_", oc_cluster_name)); 211 send_from_host_builder.Attr(kXlaHasHostTransferAttrName, true); 212 std::vector<NodeDefBuilder::NodeOut> inputs(send_from_host_dtypes.size()); 213 for (auto* n : ret_nodes) { 214 int index; 215 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); 216 if (index < 0 || index >= send_from_host_dtypes.size()) { 217 return errors::Internal("Invalid _Retval index: ", index); 218 } 219 for (auto edge : n->in_edges()) { 220 inputs[index] = 221 NodeDefBuilder::NodeOut{edge->src()->name(), edge->src_output(), 222 edge->src()->output_type(edge->src_output())}; 223 } 224 } 225 send_from_host_builder.Input(inputs); 226 send_from_host_builder.Input(key_placeholder->name(), 0, DT_STRING); 227 TF_RETURN_IF_ERROR(send_from_host_builder.Finalize(&send_from_host_def)); 228 Status s; 229 Node* send_from_host_node = g->AddNode(send_from_host_def, &s); 230 TF_RETURN_IF_ERROR(s); 231 return send_from_host_node; 232 } 233 234 // Builds XlaSendFromHost node, and replaces all _Retval nodes with it. 235 xla::StatusOr<Node*> ReplaceRetNodesWithSendFromHostNode( 236 Graph* g, const string& oc_cluster_name, 237 std::vector<DataType>* send_from_host_dtypes, Node* key_placeholder) { 238 // TODO(b/77601805): use in nodes for sink node, instead of traversing all 239 // nodes. 240 std::vector<Node*> ret_nodes = GatherNodesWithType(*g, "_Retval"); 241 TF_RETURN_IF_ERROR(GetRetDataTypes(ret_nodes, send_from_host_dtypes)); 242 TF_ASSIGN_OR_RETURN( 243 Node * send_from_host_node, 244 BuildSendFromHostNode(g, oc_cluster_name, ret_nodes, 245 *send_from_host_dtypes, key_placeholder)); 246 for (auto* n : ret_nodes) { 247 int index; 248 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); 249 for (auto edge : n->in_edges()) { 250 if (edge->src_output() == Graph::kControlSlot) { 251 g->AddControlEdge(edge->src(), send_from_host_node); 252 } else { 253 g->AddEdge(edge->src(), edge->src_output(), send_from_host_node, index); 254 } 255 } 256 g->RemoveNode(n); 257 } 258 g->AddEdge(key_placeholder, 0, send_from_host_node, 259 send_from_host_dtypes->size()); 260 return send_from_host_node; 261 } 262 263 // Returns input shapes (excluding key placeholder) for `send_from_host_node` 264 // if they are all fully defined; absl::nullopt otherwise. 265 absl::optional<std::vector<PartialTensorShape>> GetInferredInputShapes( 266 int num_inputs, Node* send_from_host_node) { 267 std::vector<PartialTensorShape> results(num_inputs); 268 for (int i = 0; i < num_inputs; i++) { 269 const Edge* e; 270 if (!send_from_host_node->input_edge(i, &e).ok()) { 271 return absl::nullopt; 272 } 273 274 std::vector<PartialTensorShape> shapes; 275 if (!GetNodeAttr(e->src()->attrs(), kXlaInferredShapesAttrName, &shapes) 276 .ok()) { 277 return absl::nullopt; 278 } 279 280 const PartialTensorShape shape = shapes[e->src_output()]; 281 if (!shape.IsFullyDefined()) { 282 return absl::nullopt; 283 } 284 285 results[e->dst_input()] = shape; 286 } 287 return results; 288 } 289 290 // Builds XlaHostCompute NodeDef from the outside compilation call node. 291 xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef( 292 const Node* call_node, const std::map<string, int>& host_compute_core) { 293 string original_oc_name; 294 TF_RETURN_IF_ERROR(GetNodeAttr( 295 call_node->attrs(), "_outside_compilation_subgraph", &original_oc_name)); 296 NodeDefBuilder host_compute_builder( 297 absl::StrCat("outside_compilation_", original_oc_name, "_host_compute"), 298 "XlaHostCompute"); 299 300 // Copy all attributes. 301 for (auto attr : call_node->attrs()) { 302 host_compute_builder.Attr(attr.first, attr.second); 303 } 304 305 // Populate tpu_core assignment. 306 const auto iter = host_compute_core.find(original_oc_name); 307 if (iter != host_compute_core.end()) { 308 int core = iter->second; 309 host_compute_builder.Attr("tpu_core", core); 310 } 311 312 // Set input tokens. 313 host_compute_builder.Attr(kXlaTokenInputNodesAttrName, 314 std::vector<string>{kXlaTokenArgNodeName}); 315 316 // Populate inputs. 317 std::vector<DataType> input_dtypes; 318 TF_RETURN_IF_ERROR(GetNodeAttr(call_node->attrs(), "Tinputs", &input_dtypes)); 319 std::vector<NodeDefBuilder::NodeOut> inputs(input_dtypes.size()); 320 for (auto e : call_node->in_edges()) { 321 if (e->IsControlEdge()) { 322 continue; 323 } 324 325 if (e->dst_input() < 0 || e->dst_input() >= input_dtypes.size()) { 326 return errors::Internal("Invalid dst_input: ", e->dst_input()); 327 } 328 inputs[e->dst_input()] = NodeDefBuilder::NodeOut{ 329 e->src()->name(), e->src_output(), input_dtypes[e->dst_input()]}; 330 } 331 host_compute_builder.Input(inputs); 332 333 NodeDef new_def; 334 TF_RETURN_IF_ERROR(host_compute_builder.Finalize(&new_def)); 335 return new_def; 336 } 337 338 Status ValidateOutsideCompilationCallNode(Node* call_node) { 339 // DT_INT64 as input/output for outside compilation is not supported yet: 340 // b/120809951. 341 for (const Edge* e : call_node->in_edges()) { 342 if (e->IsControlEdge()) { 343 continue; 344 } 345 DataType dtype = e->src()->output_type(e->src_output()); 346 if (dtype == DT_INT64) { 347 return errors::Unimplemented( 348 "int64 input for outside compilation is not supported yet: " 349 "b/120809951. Please cast output of node ", 350 e->src()->DebugString(), 351 " to int32 before feeding it into outside compilation."); 352 } 353 } 354 for (const Edge* e : call_node->out_edges()) { 355 if (e->IsControlEdge()) { 356 continue; 357 } 358 DataType dtype = e->dst()->input_type(e->dst_input()); 359 if (dtype == DT_INT64) { 360 return errors::Unimplemented( 361 "int64 output for outside compilation is not supported yet: " 362 "b/120809951. Please cast input of node ", 363 e->dst()->DebugString(), 364 " to int32 before returning it from outside compilation."); 365 } 366 } 367 return Status::OK(); 368 } 369 370 // Replace outside compilation function call node with XlaHostCompute node. 371 // If the function call node has no input/output edges, we will just remove it 372 // and not create a XlaHostCompute node. 373 Status ReplaceOrRemoveOutsideCompilationCallNode( 374 Graph* g, Node* call_node, const std::map<string, int>& host_compute_core) { 375 // If the function call node has no input/output edges, just remove it. 376 bool has_edge = false; 377 for (auto e : call_node->in_edges()) { 378 if (!e->IsControlEdge() || e->src() != g->source_node()) { 379 has_edge = true; 380 break; 381 } 382 } 383 for (auto e : call_node->out_edges()) { 384 if (!e->IsControlEdge() || e->dst() != g->sink_node()) { 385 has_edge = true; 386 break; 387 } 388 } 389 if (!has_edge) { 390 VLOG(4) << "Did not add HostCompute node for " << call_node->DebugString(); 391 g->RemoveNode(call_node); 392 return Status::OK(); 393 } 394 395 // Build XlaHostCompute NodeDef. 396 TF_ASSIGN_OR_RETURN(NodeDef node_def, 397 BuildXlaHostComputeNodeDef(call_node, host_compute_core)); 398 TF_ASSIGN_OR_RETURN(Node * host_compute_node, 399 ReplaceNode(g, call_node, node_def)); 400 VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString(); 401 402 return Status::OK(); 403 } 404 405 // Resets "device_ordinal" attr to placeholder value for related nodes 406 // (XlaRecvAtHost nodes; XlaSendFromHost nodes; If/While/FuncCall nodes 407 // containing XlaRecvAtHost/XlaSendFromHost). 408 Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) { 409 AttrValue device_ordinal_value; 410 device_ordinal_value.set_placeholder("device_ordinal"); 411 for (Node* n : g->nodes()) { 412 if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) { 413 continue; 414 } 415 416 if (n->type_string() == "_XlaRecvAtHost" || 417 n->type_string() == "_XlaSendFromHost") { 418 n->ClearAttr("device_ordinal"); 419 n->AddAttr("device_ordinal", device_ordinal_value); 420 } else if (n->type_string() == "If") { 421 for (const string& attr_name : 422 std::vector<string>{"then_branch", "else_branch"}) { 423 NameAttrList branch_func; 424 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func)); 425 (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value; 426 n->ClearAttr(attr_name); 427 n->AddAttr(attr_name, branch_func); 428 } 429 } else if (n->type_string() == "While") { 430 for (const string& attr_name : std::vector<string>{"cond", "body"}) { 431 NameAttrList branch_func; 432 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func)); 433 (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value; 434 n->ClearAttr(attr_name); 435 n->AddAttr(attr_name, branch_func); 436 } 437 } else if (HasNodeAttr(n->def(), "device_ordinal")) { 438 // Function call node containing outside compilation. 439 n->ClearAttr("device_ordinal"); 440 n->AddAttr("device_ordinal", device_ordinal_value); 441 } else { 442 return errors::Internal("Unknown node marked with ", 443 kXlaHasHostTransferAttrName, ": ", 444 n->DebugString()); 445 } 446 } 447 return Status::OK(); 448 } 449 450 // For an XLA computation, builds host side graph given all outside compilation 451 // graphs inside it. The host side graph contains: 452 // 1) a "sequencer" node (we will add control edge between XlaRecvAtHost and 453 // XlaSendFromHost to this sequencer node, so all outside compilation nodes 454 // will be executed *before* this sequencer). 455 // 2) a "key placeholder" node. Later in ExpandHostGraphIntoMainGraph(), we will 456 // replace this node with compilation result node. 457 // 3) all outside compilation graphs. 458 Status ConstructHostGraph( 459 const string& xla_cluster_name, const string& outside_compilation_attr_name, 460 const std::vector<string>& outside_compilation_host_graphs, 461 FunctionLibraryDefinition* fld, const string& host_graph_func_name) { 462 Graph host_graph(fld); 463 464 // Create sequencer node in host graph. 465 NodeDefBuilder sequencer_builder(absl::StrCat(xla_cluster_name, "_sequencer"), 466 "NoOp"); 467 sequencer_builder.Attr("_xla_host_transfer_sequencer", xla_cluster_name); 468 NodeDef sequencer_def; 469 TF_RETURN_IF_ERROR(sequencer_builder.Finalize(&sequencer_def)); 470 Status s; 471 Node* sequencer = host_graph.AddNode(sequencer_def, &s); 472 TF_RETURN_IF_ERROR(s); 473 474 // Create key placeholder in host graph. 475 TF_ASSIGN_OR_RETURN( 476 Node * key_placeholder, 477 AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); 478 479 // For each outside compilation graph, copy them to host graph with the 480 // following changes: 481 // a) Use key_placeholder in host graph instead of its own. 482 // b) Add control edge from host transfer nodes (XlaRecvAtHost, 483 // XlaSendFromHost, If/While nodes containing 484 // XlaRecvAtHost/XlaSendFromHost) to sequencer node. 485 // c) Clear node_def.device(), so device placer won't get confused. 486 for (const string& host_func : outside_compilation_host_graphs) { 487 VLOG(4) << "Expanding host graph " << host_func; 488 // Temporarily use "0" as "device_ordinal". It will be reset to placeholder 489 // value after we expanded all host graphs. We cannot just use placeholder 490 // value here because FunctionDef instantiation does not allow placeholder 491 // value for attributes. 492 AttrValue device_ordinal_attr; 493 device_ordinal_attr.set_i(0); 494 protobuf::Map<string, AttrValue> attrs; 495 attrs["device_ordinal"] = device_ordinal_attr; 496 FunctionBody* host_fbody = nullptr; 497 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( 498 *fld->Find(host_func), AttrSlice(&attrs), fld, 499 [&](const string& op, const OpDef** sig) { 500 return fld->LookUpOpDef(op, sig); 501 }, 502 &host_fbody)); 503 std::unique_ptr<FunctionBody> host_fbody_deleter(host_fbody); 504 505 // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse 506 // reachable from sink node so all nodes will be copied. 507 // TODO(b/77601805): consolidate copy graph functions. 508 FixupSourceAndSinkEdges(host_fbody->graph); 509 510 std::map<const Node*, Node*> node_map; 511 node_map[host_fbody->graph->source_node()] = host_graph.source_node(); 512 node_map[host_fbody->graph->sink_node()] = host_graph.sink_node(); 513 Status s; 514 ReverseDFS( 515 *host_fbody->graph, /*enter=*/nullptr, 516 [&](const Node* n) { 517 if (!s.ok()) { 518 return; 519 } 520 521 Node* copy; 522 if (node_map.find(n) != node_map.end()) { 523 // Already copied this node. 524 copy = node_map.at(n); 525 } else if (IsKeyPlaceholderNode(*n)) { 526 // Change a). 527 copy = key_placeholder; 528 node_map[n] = copy; 529 } else { 530 // Copy the node. 531 NodeDef copy_def = n->def(); 532 // Change c). 533 copy_def.clear_device(); 534 copy = host_graph.AddNode(copy_def, &s); 535 if (!s.ok()) { 536 return; 537 } 538 node_map[n] = copy; 539 } 540 541 // Only handle input edges. Output edges will be added later as 542 // its output nodes' input edges. 543 for (auto e : n->in_edges()) { 544 if (node_map.find(e->src()) == node_map.end()) { 545 s = errors::Internal("Cannot find node image for ", 546 e->src()->DebugString()); 547 return; 548 } 549 host_graph.AddEdge(node_map[e->src()], e->src_output(), copy, 550 e->dst_input()); 551 } 552 553 // Change b). 554 if (HasNodeAttr(copy->def(), kXlaHasHostTransferAttrName)) { 555 host_graph.AddControlEdge(copy, sequencer); 556 } 557 }, 558 NodeComparatorID()); 559 560 if (!s.ok()) { 561 return s; 562 } 563 } 564 // Reset "device_ordinal" to placeholder value. 565 TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(&host_graph)); 566 567 // sequencer and key_placeholder might be dead nodes. Prune them if necessary. 568 // - sequencer should be pruned iff it has no input control edges from 569 // RecvAtHost/SendFromHost. If it has input control edge, we connect it to 570 // sink node so it won't be pruned. 571 // - key_placeholder should be pruned iff there's no RecvAtHost/SendFromHost. 572 // We don't need to do anything special. 573 if (!sequencer->in_edges().empty()) { 574 host_graph.AddControlEdge(sequencer, host_graph.sink_node()); 575 } 576 PruneForReverseReachability( 577 &host_graph, std::unordered_set<const Node*>{host_graph.sink_node()}); 578 579 // Postprocess edges between different outside compilations. 580 TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations( 581 &host_graph, outside_compilation_attr_name)); 582 583 if (VLOG_IS_ON(4)) { 584 DumpGraphToFile(absl::StrCat("extract_outside_compilation_host_graph_for_", 585 xla_cluster_name), 586 host_graph, fld); 587 } 588 589 FunctionDef host_graph_fdef; 590 TF_RETURN_IF_ERROR( 591 GraphToFunctionDef(host_graph, host_graph_func_name, &host_graph_fdef)); 592 if (fld->Find(host_graph_func_name)) { 593 TF_RETURN_IF_ERROR( 594 fld->ReplaceFunction(host_graph_func_name, host_graph_fdef)); 595 } else { 596 TF_RETURN_IF_ERROR(fld->AddFunctionDef(host_graph_fdef)); 597 } 598 599 return Status::OK(); 600 } 601 602 // Expand XLA computation's outside compilation host side graph into main graph. 603 // Add a control edge between sequencer node and the XLA computation node. 604 Status ExpandHostGraphIntoMainGraph(Graph* main_graph, 605 FunctionLibraryDefinition* fld, 606 const string& host_graph_func_name, 607 Node* xla_computation_node) { 608 // Temporarily use "0" as "device_ordinal". It will be rewritten with the 609 // correct value in a later pass. We cannot just use placeholder value here 610 // because FunctionDef instantiation does not allow placeholder value for 611 // attributes. 612 AttrValue device_ordinal_attr; 613 device_ordinal_attr.set_i(0); 614 protobuf::Map<string, AttrValue> attrs; 615 attrs["device_ordinal"] = device_ordinal_attr; 616 FunctionBody* fbody = nullptr; 617 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( 618 *fld->Find(host_graph_func_name), AttrSlice(&attrs), fld, 619 [&](const string& op, const OpDef** sig) { 620 return fld->LookUpOpDef(op, sig); 621 }, 622 &fbody)); 623 std::unique_ptr<FunctionBody> fbody_deleter(fbody); 624 Graph* host_graph = fbody->graph; 625 626 // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse 627 // reachable from sink node so all nodes will be copied. 628 // TODO(b/77601805): consolidate copy graph functions. 629 FixupSourceAndSinkEdges(host_graph); 630 631 // Copy all nodes. 632 std::map<const Node*, Node*> node_map; 633 node_map[host_graph->source_node()] = main_graph->source_node(); 634 node_map[host_graph->sink_node()] = main_graph->sink_node(); 635 Status s = Status::OK(); 636 auto copy_node_fn = [&](const Node* n) { 637 if (!s.ok()) { 638 return; 639 } 640 641 Node* copy; 642 if (node_map.find(n) != node_map.end()) { 643 // Already copied this node. 644 copy = node_map.at(n); 645 } else { 646 // Copy the node. 647 NodeDef copy_def = n->def(); 648 copy = main_graph->AddNode(copy_def, &s); 649 if (!s.ok()) { 650 return; 651 } 652 node_map[n] = copy; 653 } 654 655 // Only handle input edges. Output edges will be added later as its output 656 // nodes' input edges. 657 for (auto e : n->in_edges()) { 658 if (node_map.find(e->src()) == node_map.end()) { 659 s = errors::Internal("Cannot find node image for ", 660 e->src()->DebugString()); 661 return; 662 } 663 main_graph->AddEdge(node_map[e->src()], e->src_output(), copy, 664 e->dst_input()); 665 } 666 667 // Add control edge from sequencer to XLA computation node. 668 if (copy->type_string() == "NoOp" && 669 HasNodeAttr(copy->def(), "_xla_host_transfer_sequencer")) { 670 main_graph->AddControlEdge(copy, xla_computation_node); 671 } 672 }; 673 ReverseDFS(*host_graph, /*enter=*/nullptr, copy_node_fn, NodeComparatorID()); 674 return s; 675 } 676 677 // Rewrites shape inference graph for outside compilation: 678 // 1) If XlaSendFromHost also exists in `host_graph`, copy nodes from 679 // `host_graph`. Because we might still have outside compilation to outside 680 // compilation placeholder nodes in shape inference graph, which will prevent 681 // us from inferring XlaSendFromHost shape. But in `host_graph`, we already 682 // removed those placeholder nodes. 683 // 2) Remove control edges. 684 // 3) Prune nodes that are not useful for shape inference. 685 Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, 686 Graph* host_graph, 687 FunctionLibraryDefinition* fld) { 688 // Use "0" as "device_ordinal". It does not matter for shape inference. 689 AttrValue device_ordinal_attr; 690 device_ordinal_attr.set_i(0); 691 protobuf::Map<string, AttrValue> attrs; 692 attrs["device_ordinal"] = device_ordinal_attr; 693 FunctionBody* fbody = nullptr; 694 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( 695 *fld->Find(shape_inference_graph_name), AttrSlice(&attrs), fld, 696 [&](const string& op, const OpDef** sig) { 697 return fld->LookUpOpDef(op, sig); 698 }, 699 &fbody)); 700 std::unique_ptr<FunctionBody> fbody_deleter(fbody); 701 Graph* g = fbody->graph; 702 703 // Find SendFromHost node. 704 Node* send_from_host = nullptr; 705 for (Node* n : g->nodes()) { 706 if (n->type_string() == "_XlaSendFromHost") { 707 send_from_host = n; 708 break; 709 } 710 } 711 if (!send_from_host) { 712 return errors::Internal("Shape inference graph ", 713 shape_inference_graph_name, 714 " does not have _XlaSendFromHost node."); 715 } 716 717 // See if the SendFromHost node exists in `host_graph`. 718 Node* send_from_host_main_graph = nullptr; 719 for (Node* n : host_graph->nodes()) { 720 if (n->name() == send_from_host->name()) { 721 send_from_host_main_graph = n; 722 break; 723 } 724 } 725 if (send_from_host_main_graph) { 726 // This is an "top-level" outside compilation. Clear the graph, and copy 727 // SendFromHost and all its predecessors from `host_graph`. 728 std::vector<Node*> nodes; 729 for (Node* n : g->op_nodes()) { 730 nodes.push_back(n); 731 } 732 for (Node* n : nodes) { 733 g->RemoveNode(n); 734 } 735 736 std::map<const Node*, Node*> node_map; 737 node_map[host_graph->source_node()] = g->source_node(); 738 Status s; 739 auto copy_node_fn = [&](const Node* n) { 740 if (!s.ok()) { 741 return; 742 } 743 744 if (node_map.find(n) != node_map.end()) { 745 return; 746 } 747 748 NodeDef copy_def = n->def(); 749 Node* copy = g->AddNode(copy_def, &s); 750 if (!s.ok()) { 751 return; 752 } 753 for (auto e : n->in_edges()) { 754 if (node_map.find(e->src()) == node_map.end()) { 755 s = errors::Internal("Cannot find node image for ", 756 e->src()->DebugString()); 757 return; 758 } 759 g->AddEdge(node_map[e->src()], e->src_output(), copy, e->dst_input()); 760 } 761 762 node_map[n] = copy; 763 }; 764 // TODO(b/77601805): consolidate copy graph functions. 765 ReverseDFSFrom(*host_graph, 766 std::vector<const Node*>{send_from_host_main_graph}, 767 /*enter=*/nullptr, copy_node_fn, NodeComparatorID()); 768 if (!s.ok()) { 769 return s; 770 } 771 772 send_from_host = node_map[send_from_host_main_graph]; 773 } else { 774 // This is an outside compilation embedded in If/While/gradient/etc. 775 // It will be enough for shape inference. Leave `g` unchanged. 776 } 777 778 // Control edges are not useful for shape inference. Remove them. 779 for (auto e : g->edges()) { 780 if (e->IsControlEdge()) { 781 g->RemoveEdge(e); 782 } 783 } 784 785 // Nodes that are not reverse reachable from SendFromHost are not useful for 786 // shape inference. Prune them. 787 PruneForReverseReachability(g, 788 std::unordered_set<const Node*>{send_from_host}); 789 790 if (VLOG_IS_ON(4)) { 791 DumpGraphToFile(shape_inference_graph_name, *g, fld); 792 } 793 794 // Replace original shape inference graph. 795 FunctionDef fdef_replace; 796 TF_RETURN_IF_ERROR( 797 GraphToFunctionDef(*g, shape_inference_graph_name, &fdef_replace)); 798 TF_RETURN_IF_ERROR( 799 fld->ReplaceFunction(shape_inference_graph_name, fdef_replace)); 800 801 return Status::OK(); 802 } 803 804 // Builds XlaSendToHost node which sends cond predicate to host. 805 xla::StatusOr<Node*> BuildSendIfPredNode(const string& name, 806 const string& host_transfer_key, 807 Node* pred_node, Graph* g) { 808 NodeDefBuilder send_pred_builder(name, "XlaSendToHost"); 809 send_pred_builder.Attr("Tinput", DT_BOOL); 810 send_pred_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0")); 811 send_pred_builder.Attr(kXlaTokenInputNodesAttrName, 812 std::vector<string>{kXlaTokenArgNodeName}); 813 send_pred_builder.Input(pred_node->name(), 0, DT_BOOL); 814 NodeDef send_pred_def; 815 TF_RETURN_IF_ERROR(send_pred_builder.Finalize(&send_pred_def)); 816 Status s; 817 Node* send_pred_node = g->AddNode(send_pred_def, &s); 818 TF_RETURN_IF_ERROR(s); 819 g->AddEdge(pred_node, 0, send_pred_node, 0); 820 return send_pred_node; 821 } 822 823 // Replaces key placeholder node with an _Arg node. 824 Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name, 825 const string& func_name, 826 FunctionLibraryDefinition* fld) { 827 // Temporarily use "0" as "device_ordinal". It will be reset to placeholder 828 // value after rewriting. 829 AttrValue device_ordinal_attr; 830 device_ordinal_attr.set_i(0); 831 protobuf::Map<string, AttrValue> attrs; 832 attrs["device_ordinal"] = device_ordinal_attr; 833 FunctionBody* fbody = nullptr; 834 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( 835 *fld->Find(func_name), AttrSlice(&attrs), fld, 836 [&](const string& op, const OpDef** sig) { 837 return fld->LookUpOpDef(op, sig); 838 }, 839 &fbody)); 840 std::unique_ptr<FunctionBody> fbody_deleter(fbody); 841 Graph* g = fbody->graph; 842 843 // Find or create the key placeholder node. 844 Node* key_placeholder = nullptr; 845 for (Node* n : g->nodes()) { 846 if (IsKeyPlaceholderNode(*n)) { 847 key_placeholder = n; 848 break; 849 } 850 } 851 if (!key_placeholder) { 852 TF_ASSIGN_OR_RETURN(key_placeholder, 853 AddHostComputeKeyPlaceholder(xla_cluster_name, g)); 854 } 855 856 // Build the _Arg node, and replace key placeholder node with it. 857 NodeDefBuilder arg_builder("key_arg", FunctionLibraryDefinition::kArgOp); 858 arg_builder.Attr("T", DT_STRING); 859 arg_builder.Attr("index", 0); 860 NodeDef arg_def; 861 TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def)); 862 TF_RETURN_IF_ERROR(ReplaceNode(g, key_placeholder, arg_def).status()); 863 864 // Reset "device_ordinal" to placeholder value. 865 TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(g)); 866 867 FunctionDef replace_fdef; 868 TF_RETURN_IF_ERROR(GraphToFunctionDef(*g, func_name, &replace_fdef)); 869 TF_RETURN_IF_ERROR(fld->ReplaceFunction(func_name, replace_fdef)); 870 return Status::OK(); 871 } 872 873 // Builds host side graph for If node. 874 Status BuildHostGraphForIfNode(const string& xla_cluster_attr_name, 875 const string& outside_compilation_attr_name, 876 const string& xla_cluster_name, 877 const string& if_node_name, 878 const string& host_transfer_key, 879 const string& host_graph_func_name, 880 FunctionLibraryDefinition* fld, 881 const string& then_branch_host_func_name, 882 const string& else_branch_host_func_name) { 883 Graph host_graph(fld); 884 string outside_compilation_name = absl::StrCat("oc_if_", if_node_name); 885 AttrValue device_ordinal_value; 886 device_ordinal_value.set_placeholder("device_ordinal"); 887 888 // Step 1: add key placeholder node. 889 TF_ASSIGN_OR_RETURN( 890 Node * key_placeholder, 891 AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); 892 893 // Step 2: build XlaRecvAtHost node to recv predicate. 894 NodeDefBuilder recv_pred_builder( 895 absl::StrCat("recv_oc_if_pred_", if_node_name), "_XlaRecvAtHost"); 896 recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL}); 897 recv_pred_builder.Attr("key", host_transfer_key); 898 recv_pred_builder.Attr("device_ordinal", device_ordinal_value); 899 recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name); 900 recv_pred_builder.Attr(outside_compilation_attr_name, 901 outside_compilation_name); 902 recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true); 903 recv_pred_builder.Input(key_placeholder->name(), 0, DT_STRING); 904 NodeDef recv_pred_def; 905 TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def)); 906 Status s; 907 Node* recv_pred_node = host_graph.AddNode(recv_pred_def, &s); 908 TF_RETURN_IF_ERROR(s); 909 host_graph.AddEdge(key_placeholder, 0, recv_pred_node, 0); 910 911 // Step 3: rewrite `{then, else}_branch_host_func_name`, replace key 912 // placeholder with an _Arg node. 913 TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( 914 xla_cluster_name, then_branch_host_func_name, fld)); 915 TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( 916 xla_cluster_name, else_branch_host_func_name, fld)); 917 918 // Step 4: build If node to choose between `{then, else}_branch_host_graph`. 919 NodeDefBuilder if_builder(absl::StrCat("oc_if_", if_node_name), "If"); 920 if_builder.Attr("Tcond", DT_BOOL); 921 if_builder.Attr("Tin", std::vector<DataType>{DT_STRING}); 922 if_builder.Attr("Tout", std::vector<DataType>{}); 923 NameAttrList host_then_branch, host_else_branch; 924 host_then_branch.set_name(then_branch_host_func_name); 925 (*host_then_branch.mutable_attr())["device_ordinal"] = device_ordinal_value; 926 host_else_branch.set_name(else_branch_host_func_name); 927 (*host_else_branch.mutable_attr())["device_ordinal"] = device_ordinal_value; 928 if_builder.Attr("then_branch", host_then_branch); 929 if_builder.Attr("else_branch", host_else_branch); 930 if_builder.Attr(kXlaHasHostTransferAttrName, true); 931 if_builder.Attr(xla_cluster_attr_name, xla_cluster_name); 932 if_builder.Attr(outside_compilation_attr_name, outside_compilation_name); 933 if_builder.Input(recv_pred_node->name(), 0, DT_BOOL); 934 std::vector<NodeDefBuilder::NodeOut> if_inputs{ 935 {key_placeholder->name(), 0, DT_STRING}}; 936 if_builder.Input(if_inputs); 937 NodeDef if_def; 938 TF_RETURN_IF_ERROR(if_builder.Finalize(&if_def)); 939 Node* if_node = host_graph.AddNode(if_def, &s); 940 TF_RETURN_IF_ERROR(s); 941 host_graph.AddEdge(recv_pred_node, 0, if_node, 0); 942 host_graph.AddEdge(key_placeholder, 0, if_node, 1); 943 944 // Convert `host_graph` to function, and add a "device_ordinal" attr. 945 FunctionDef oc_host_graph_fdef; 946 TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name, 947 &oc_host_graph_fdef)); 948 if (fld->Find(host_graph_func_name)) { 949 TF_RETURN_IF_ERROR( 950 fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef)); 951 } else { 952 TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef)); 953 } 954 955 return Status::OK(); 956 } 957 958 // Rewrites loop cond to add a node which sends loop cond to host. 959 Status AddSendLoopPredToLoopCond(FunctionLibraryDefinition* fld, 960 const NameAttrList& loop_cond_func, 961 const string& while_node_name, 962 const string& host_transfer_key) { 963 // Instantiate the loop cond function. 964 FunctionBody* fbody = nullptr; 965 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( 966 *fld->Find(loop_cond_func.name()), AttrSlice(&loop_cond_func.attr()), fld, 967 [&](const string& op, const OpDef** sig) { 968 return fld->LookUpOpDef(op, sig); 969 }, 970 &fbody)); 971 std::unique_ptr<FunctionBody> fbody_deleter(fbody); 972 Graph* g = fbody->graph; 973 974 // Find the _Retval node and the loop cond node. 975 Node* ret_node = nullptr; 976 for (Node* n : g->nodes()) { 977 if (n->type_string() == "_Retval") { 978 if (ret_node) { 979 return errors::Internal("Multiple return node for loop cond function ", 980 loop_cond_func.name(), ": ", 981 ret_node->DebugString(), " and ", 982 n->DebugString()); 983 } else { 984 ret_node = n; 985 } 986 } 987 } 988 if (!ret_node) { 989 return errors::Internal("No _Retval node for loop cond function ", 990 loop_cond_func.name()); 991 } 992 Node* loop_cond; 993 TF_RETURN_IF_ERROR(ret_node->input_node(0, &loop_cond)); 994 995 // Build the XlaSendToHost node. 996 NodeDefBuilder send_loop_cond_builder( 997 absl::StrCat("send_oc_while_cond_", while_node_name), "XlaSendToHost"); 998 send_loop_cond_builder.Attr("Tinput", DT_BOOL); 999 send_loop_cond_builder.Attr("key", 1000 absl::StrCat(host_transfer_key, "_dtoh_0")); 1001 send_loop_cond_builder.Attr(kXlaTokenInputNodesAttrName, 1002 std::vector<string>{kXlaTokenArgNodeName}); 1003 send_loop_cond_builder.Input(loop_cond->name(), 0, DT_BOOL); 1004 NodeDef send_loop_cond_def; 1005 TF_RETURN_IF_ERROR(send_loop_cond_builder.Finalize(&send_loop_cond_def)); 1006 Status s; 1007 Node* send_loop_cond_node = g->AddNode(send_loop_cond_def, &s); 1008 TF_RETURN_IF_ERROR(s); 1009 g->AddEdge(loop_cond, 0, send_loop_cond_node, 0); 1010 1011 // Replace original function. 1012 FunctionDef replace_fdef; 1013 TF_RETURN_IF_ERROR( 1014 GraphToFunctionDef(*g, loop_cond_func.name(), &replace_fdef)); 1015 TF_RETURN_IF_ERROR(fld->ReplaceFunction(loop_cond_func.name(), replace_fdef)); 1016 1017 return Status::OK(); 1018 } 1019 1020 // Rewrites while loop cond function for host. 1021 Status RewriteHostWhileLoopCond( 1022 const string& cond_host_func_name, const string& while_node_name, 1023 const string& host_transfer_key, const string& xla_cluster_attr_name, 1024 const string& xla_cluster_name, const string& outside_compilation_attr_name, 1025 const string& outside_compilation_name, FunctionLibraryDefinition* fld) { 1026 // Replace key placeholder node with _Arg node. 1027 TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( 1028 xla_cluster_name, cond_host_func_name, fld)); 1029 1030 // Instantiate cond function. 1031 AttrValue device_ordinal_temp_value; 1032 device_ordinal_temp_value.set_i(0); 1033 protobuf::Map<string, AttrValue> attrs; 1034 attrs["device_ordinal"] = device_ordinal_temp_value; 1035 FunctionBody* cond_fbody = nullptr; 1036 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( 1037 *fld->Find(cond_host_func_name), AttrSlice(&attrs), fld, 1038 [&](const string& op, const OpDef** sig) { 1039 return fld->LookUpOpDef(op, sig); 1040 }, 1041 &cond_fbody)); 1042 std::unique_ptr<FunctionBody> cond_fbody_deleter(cond_fbody); 1043 Graph* cond_graph = cond_fbody->graph; 1044 Node* key_arg = nullptr; 1045 for (Node* n : cond_graph->nodes()) { 1046 if (n->type_string() == "_Arg") { 1047 key_arg = n; 1048 } 1049 } 1050 if (!key_arg) { 1051 return errors::Internal( 1052 "No _Arg node found for host compute key in function ", 1053 cond_host_func_name); 1054 } 1055 1056 // Add an XlaRecvAtHost node to use as cond function return value. 1057 // We don't need to set kXlaHasHostTransferAttrName for this node, because 1058 // it's already added for the "While" node on the host. 1059 NodeDefBuilder recv_pred_builder( 1060 absl::StrCat("recv_oc_while_cond_", while_node_name), "_XlaRecvAtHost"); 1061 recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL}); 1062 recv_pred_builder.Attr("key", host_transfer_key); 1063 AttrValue device_ordinal_value; 1064 device_ordinal_value.set_placeholder("device_ordinal"); 1065 recv_pred_builder.Attr("device_ordinal", device_ordinal_value); 1066 recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name); 1067 recv_pred_builder.Attr(outside_compilation_attr_name, 1068 outside_compilation_name); 1069 recv_pred_builder.Input(key_arg->name(), 0, DT_STRING); 1070 NodeDef recv_pred_def; 1071 TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def)); 1072 Status s; 1073 Node* recv_pred_node = cond_graph->AddNode(recv_pred_def, &s); 1074 TF_RETURN_IF_ERROR(s); 1075 cond_graph->AddEdge(key_arg, 0, recv_pred_node, 0); 1076 NodeDefBuilder ret_builder( 1077 absl::StrCat("recv_oc_while_cond_ret_", while_node_name), "_Retval"); 1078 ret_builder.Attr("T", DT_BOOL); 1079 ret_builder.Attr("index", 0); 1080 ret_builder.Input(recv_pred_node->name(), 0, DT_BOOL); 1081 NodeDef ret_def; 1082 TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def)); 1083 Node* ret_node = cond_graph->AddNode(ret_def, &s); 1084 TF_RETURN_IF_ERROR(s); 1085 cond_graph->AddEdge(recv_pred_node, 0, ret_node, 0); 1086 1087 // Reset device_ordinal to placeholder value. 1088 TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(cond_graph)); 1089 1090 // Replace original function. 1091 FunctionDef cond_replace_fdef; 1092 TF_RETURN_IF_ERROR( 1093 GraphToFunctionDef(*cond_graph, cond_host_func_name, &cond_replace_fdef)); 1094 TF_RETURN_IF_ERROR( 1095 fld->ReplaceFunction(cond_host_func_name, cond_replace_fdef)); 1096 1097 return Status::OK(); 1098 } 1099 1100 // Rewrites while loop body function for host. 1101 Status RewriteHostWhileLoopBody( 1102 const string& body_host_func_name, const string& while_node_name, 1103 const string& host_transfer_key, const string& xla_cluster_attr_name, 1104 const string& xla_cluster_name, const string& outside_compilation_attr_name, 1105 const string& outside_compilation_name, FunctionLibraryDefinition* fld) { 1106 // Replace key placeholder node with _Arg node. 1107 TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( 1108 xla_cluster_name, body_host_func_name, fld)); 1109 1110 // Instantiate body function. 1111 AttrValue device_ordinal_temp_value; 1112 device_ordinal_temp_value.set_i(0); 1113 protobuf::Map<string, AttrValue> attrs; 1114 attrs["device_ordinal"] = device_ordinal_temp_value; 1115 FunctionBody* body_fbody = nullptr; 1116 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( 1117 *fld->Find(body_host_func_name), AttrSlice(&attrs), fld, 1118 [&](const string& op, const OpDef** sig) { 1119 return fld->LookUpOpDef(op, sig); 1120 }, 1121 &body_fbody)); 1122 std::unique_ptr<FunctionBody> body_fbody_deleter(body_fbody); 1123 Graph* body_graph = body_fbody->graph; 1124 Node* key_arg = nullptr; 1125 for (Node* n : body_graph->nodes()) { 1126 if (n->type_string() == "_Arg") { 1127 key_arg = n; 1128 } 1129 } 1130 if (!key_arg) { 1131 return errors::Internal( 1132 "No _Arg node found for host compute key in function ", 1133 body_host_func_name); 1134 } 1135 1136 // Add a _Retval node to loop body. 1137 NodeDefBuilder ret_builder( 1138 absl::StrCat("recv_oc_while_body_ret_", while_node_name), "_Retval"); 1139 ret_builder.Attr("T", DT_STRING); 1140 ret_builder.Attr("index", 0); 1141 ret_builder.Input(key_arg->name(), 0, DT_STRING); 1142 NodeDef ret_def; 1143 TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def)); 1144 Status s; 1145 Node* ret_node = body_graph->AddNode(ret_def, &s); 1146 TF_RETURN_IF_ERROR(s); 1147 body_graph->AddEdge(key_arg, 0, ret_node, 0); 1148 1149 // Reset device_ordinal to placeholder value. 1150 TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(body_graph)); 1151 1152 // Replace original function. 1153 FunctionDef body_replace_fdef; 1154 TF_RETURN_IF_ERROR( 1155 GraphToFunctionDef(*body_graph, body_host_func_name, &body_replace_fdef)); 1156 TF_RETURN_IF_ERROR( 1157 fld->ReplaceFunction(body_host_func_name, body_replace_fdef)); 1158 1159 return Status::OK(); 1160 } 1161 1162 // Builds host side graph for while node. 1163 Status BuildHostGraphForWhileNode( 1164 const string& xla_cluster_attr_name, 1165 const string& outside_compilation_attr_name, const string& xla_cluster_name, 1166 const string& while_node_name, const string& host_transfer_key, 1167 const string& host_graph_func_name, FunctionLibraryDefinition* fld, 1168 const string& cond_host_func_name, const string& body_host_func_name) { 1169 Graph host_graph(fld); 1170 string outside_compilation_name = absl::StrCat("oc_while_", while_node_name); 1171 1172 // Step 1: add key placeholder node. 1173 TF_ASSIGN_OR_RETURN( 1174 Node * key_placeholder, 1175 AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); 1176 1177 // Step 2: rewrite cond function. 1178 TF_RETURN_IF_ERROR(RewriteHostWhileLoopCond( 1179 cond_host_func_name, while_node_name, host_transfer_key, 1180 xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name, 1181 outside_compilation_name, fld)); 1182 1183 // Step 3: rewrite body function. 1184 TF_RETURN_IF_ERROR(RewriteHostWhileLoopBody( 1185 body_host_func_name, while_node_name, host_transfer_key, 1186 xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name, 1187 outside_compilation_name, fld)); 1188 1189 // Step 4: build While node. 1190 NodeDefBuilder while_builder(absl::StrCat("oc_while_", while_node_name), 1191 "While"); 1192 while_builder.Attr("T", std::vector<DataType>{DT_STRING}); 1193 NameAttrList func; 1194 AttrValue device_ordinal_value; 1195 device_ordinal_value.set_placeholder("device_ordinal"); 1196 (*func.mutable_attr())["device_ordinal"] = device_ordinal_value; 1197 func.set_name(cond_host_func_name); 1198 while_builder.Attr("cond", func); 1199 func.set_name(body_host_func_name); 1200 while_builder.Attr("body", func); 1201 while_builder.Attr(kXlaHasHostTransferAttrName, true); 1202 while_builder.Attr(xla_cluster_attr_name, xla_cluster_name); 1203 while_builder.Attr(outside_compilation_attr_name, outside_compilation_name); 1204 std::vector<NodeDefBuilder::NodeOut> while_inputs{ 1205 {key_placeholder->name(), 0, DT_STRING}}; 1206 while_builder.Input(while_inputs); 1207 NodeDef while_def; 1208 TF_RETURN_IF_ERROR(while_builder.Finalize(&while_def)); 1209 Status s; 1210 Node* while_node = host_graph.AddNode(while_def, &s); 1211 TF_RETURN_IF_ERROR(s); 1212 host_graph.AddEdge(key_placeholder, 0, while_node, 0); 1213 1214 // Convert `host_graph` to function. 1215 FunctionDef oc_host_graph_fdef; 1216 TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name, 1217 &oc_host_graph_fdef)); 1218 if (fld->Find(host_graph_func_name)) { 1219 TF_RETURN_IF_ERROR( 1220 fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef)); 1221 } else { 1222 TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef)); 1223 } 1224 1225 return Status::OK(); 1226 } 1227 1228 // Builds host graph for func call nodes. 1229 Status BuildHostGraphForFuncCallNode(const string& func_call_node_name, 1230 const string& xla_cluster_name, 1231 const string& func_call_host_func_name, 1232 const string& host_graph_func_name, 1233 FunctionLibraryDefinition* fld) { 1234 Graph host_graph(fld); 1235 AttrValue device_ordinal_value; 1236 device_ordinal_value.set_placeholder("device_ordinal"); 1237 1238 // Step 1: add key placeholder node. 1239 TF_ASSIGN_OR_RETURN( 1240 Node * key_placeholder, 1241 AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); 1242 1243 // Step 2: rewrite `host_func_name`, replace key placeholder with an _Arg 1244 // node. 1245 TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( 1246 xla_cluster_name, func_call_host_func_name, fld)); 1247 1248 // Step 3: build a function call node with `host_func_name`, with 1249 // `key_placeholder` as input. 1250 NodeDefBuilder call_builder(absl::StrCat("oc_call_", func_call_node_name), 1251 func_call_host_func_name, fld); 1252 call_builder.Input(key_placeholder->name(), 0, DT_STRING); 1253 call_builder.Attr("device_ordinal", device_ordinal_value); 1254 call_builder.Attr(kXlaHasHostTransferAttrName, true); 1255 NodeDef call_def; 1256 TF_RETURN_IF_ERROR(call_builder.Finalize(&call_def)); 1257 Status s; 1258 Node* call_node = host_graph.AddNode(call_def, &s); 1259 TF_RETURN_IF_ERROR(s); 1260 host_graph.AddEdge(key_placeholder, 0, call_node, 0); 1261 1262 // Convert `host_graph` to function, and add a "device_ordinal" attr. 1263 FunctionDef oc_host_graph_fdef; 1264 TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name, 1265 &oc_host_graph_fdef)); 1266 if (fld->Find(host_graph_func_name)) { 1267 TF_RETURN_IF_ERROR( 1268 fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef)); 1269 } else { 1270 TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef)); 1271 } 1272 1273 return Status::OK(); 1274 } 1275 1276 Status ExtractOutsideCompilationForNodesWithAssociatedFunctions( 1277 Graph* g, const string& xla_cluster_attr_name, 1278 const string& outside_compilation_attr_name, const string& xla_cluster_name, 1279 const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr, 1280 FunctionLibraryDefinition* fld, std::vector<string>* host_graphs, 1281 std::vector<string>* shape_inference_graphs, 1282 bool* has_outside_compilation) { 1283 std::vector<Node*> if_nodes, while_nodes, func_call_nodes; 1284 for (Node* n : g->nodes()) { 1285 if (n->type_string() == "If") { 1286 if_nodes.push_back(n); 1287 } else if (n->type_string() == "While") { 1288 while_nodes.push_back(n); 1289 } else if (fld->Contains(n->type_string())) { 1290 func_call_nodes.push_back(n); 1291 } else if (n->type_string() == FunctionLibraryDefinition::kGradientOp) { 1292 // Only gradient for user-defined function should be considered as 1293 // function call node. 1294 NameAttrList original_func; 1295 TF_RETURN_IF_ERROR(GetNodeAttr( 1296 n->def(), FunctionLibraryDefinition::kFuncAttr, &original_func)); 1297 if (fld->Contains(original_func.name())) { 1298 func_call_nodes.push_back(n); 1299 } 1300 } 1301 } 1302 1303 for (Node* n : func_call_nodes) { 1304 // Extract outside compilation for the function call. 1305 bool func_has_outside_compilation = false; 1306 NameAttrList func; 1307 func.set_name(n->type_string()); 1308 typedef protobuf::Map<string, AttrValue> AttrMap; 1309 *func.mutable_attr() = AttrMap(n->attrs().begin(), n->attrs().end()); 1310 string new_func_name = absl::StrCat(n->name(), "_oc"); 1311 string host_func_name = absl::StrCat("oc_func_call_host_", n->name()); 1312 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( 1313 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, 1314 func, new_func_name, host_func_name, host_compute_core, flr, fld, 1315 shape_inference_graphs, &func_has_outside_compilation)); 1316 1317 // If the function call does not have outside compilation, nothing to do. 1318 if (!func_has_outside_compilation) { 1319 continue; 1320 } 1321 1322 *has_outside_compilation = true; 1323 1324 // Change `n` to call the new function directly. 1325 NodeDefBuilder replace_builder(n->name(), new_func_name, fld); 1326 for (const Edge* e : n->in_edges()) { 1327 if (e->IsControlEdge()) { 1328 continue; 1329 } 1330 replace_builder.Input(e->src()->name(), e->src_output(), 1331 e->src()->output_type(e->src_output())); 1332 } 1333 for (const auto& attr : n->attrs()) { 1334 replace_builder.Attr(attr.first, attr.second); 1335 } 1336 NodeDef replace_def; 1337 TF_RETURN_IF_ERROR(replace_builder.Finalize(&replace_def)); 1338 TF_ASSIGN_OR_RETURN(Node * replace, ReplaceNode(g, n, replace_def)); 1339 replace->AddAttr(kXlaTokenInputNodesAttrName, 1340 std::vector<string>{kXlaTokenArgNodeName}); 1341 1342 // Build host side graph for the function call. 1343 string oc_host_graph_name = 1344 absl::StrCat("oc_func_host_graph_", replace->name()); 1345 TF_RETURN_IF_ERROR( 1346 BuildHostGraphForFuncCallNode(replace->name(), xla_cluster_name, 1347 host_func_name, oc_host_graph_name, fld)); 1348 1349 // Record the host graph. 1350 host_graphs->push_back(oc_host_graph_name); 1351 } 1352 1353 for (Node* n : if_nodes) { 1354 // Instantiate "then_branch" and "else_branch". 1355 NameAttrList then_branch, else_branch; 1356 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "then_branch", &then_branch)); 1357 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "else_branch", &else_branch)); 1358 1359 // Extract outside compilation for then_branch and else_branch. 1360 bool then_branch_has_outside_compilation = false; 1361 bool else_branch_has_outside_compilation = false; 1362 string then_branch_host_func_name = 1363 absl::StrCat("oc_then_branch_host_if_", n->name()), 1364 else_branch_host_func_name = 1365 absl::StrCat("oc_else_branch_host_if_", n->name()); 1366 string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"), 1367 else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc"); 1368 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( 1369 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, 1370 then_branch, then_branch_xla_func_name, then_branch_host_func_name, 1371 host_compute_core, flr, fld, shape_inference_graphs, 1372 &then_branch_has_outside_compilation)); 1373 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( 1374 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, 1375 else_branch, else_branch_xla_func_name, else_branch_host_func_name, 1376 host_compute_core, flr, fld, shape_inference_graphs, 1377 &else_branch_has_outside_compilation)); 1378 1379 // If then/else branch do not have outside compilation, nothing to do. 1380 if (!then_branch_has_outside_compilation && 1381 !else_branch_has_outside_compilation) { 1382 continue; 1383 } 1384 1385 *has_outside_compilation = true; 1386 1387 // Change If node to call the new functions. 1388 then_branch.set_name(then_branch_xla_func_name); 1389 n->ClearAttr("then_branch"); 1390 n->AddAttr("then_branch", then_branch); 1391 else_branch.set_name(else_branch_xla_func_name); 1392 n->ClearAttr("else_branch"); 1393 n->AddAttr("else_branch", else_branch); 1394 1395 string host_transfer_key = absl::StrCat("oc_if_pred_", n->name()); 1396 1397 // XLA computation: add a SendToHost node to send cond predicate. 1398 Node* pred_node; 1399 TF_RETURN_IF_ERROR(n->input_node(0, &pred_node)); 1400 TF_ASSIGN_OR_RETURN( 1401 Node * send_pred_node, 1402 BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()), 1403 host_transfer_key, pred_node, g)); 1404 n->AddAttr(kXlaTokenInputNodesAttrName, 1405 std::vector<string>{send_pred_node->name()}); 1406 1407 // Add a control edge from `send_pred_node` to If node, so XlaCompiler will 1408 // visit If node after `send_pred_node`, thus the token output for 1409 // `send_pred_node` has been generated. 1410 g->AddControlEdge(send_pred_node, n); 1411 1412 // Build host side graph for the "If" node. 1413 string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name()); 1414 TF_RETURN_IF_ERROR(BuildHostGraphForIfNode( 1415 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, 1416 n->name(), host_transfer_key, oc_host_graph_name, fld, 1417 then_branch_host_func_name, else_branch_host_func_name)); 1418 host_graphs->push_back(oc_host_graph_name); 1419 } 1420 1421 for (Node* n : while_nodes) { 1422 // Instantiate "cond" and "body". 1423 NameAttrList cond, body; 1424 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "cond", &cond)); 1425 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "body", &body)); 1426 1427 // Extract outside compilation for cond and body. 1428 bool cond_has_outside_compilation = false; 1429 bool body_has_outside_compilation = false; 1430 string cond_host_func_name = absl::StrCat("oc_cond_host_while_", n->name()), 1431 body_host_func_name = absl::StrCat("oc_body_host_while_", n->name()); 1432 string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"), 1433 body_xla_func_name = absl::StrCat(body.name(), "_oc"); 1434 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( 1435 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, 1436 cond, cond_xla_func_name, cond_host_func_name, host_compute_core, flr, 1437 fld, shape_inference_graphs, &cond_has_outside_compilation)); 1438 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( 1439 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, 1440 body, body_xla_func_name, body_host_func_name, host_compute_core, flr, 1441 fld, shape_inference_graphs, &body_has_outside_compilation)); 1442 1443 // If cond/body do not have outside compilation, nothing to do. 1444 if (!cond_has_outside_compilation && !body_has_outside_compilation) { 1445 continue; 1446 } 1447 1448 *has_outside_compilation = true; 1449 1450 // Change While node to call the new functions. 1451 cond.set_name(cond_xla_func_name); 1452 n->ClearAttr("cond"); 1453 n->AddAttr("cond", cond); 1454 body.set_name(body_xla_func_name); 1455 n->ClearAttr("body"); 1456 n->AddAttr("body", body); 1457 1458 string host_transfer_key = absl::StrCat("oc_while_pred_", n->name()); 1459 1460 // XLA computation: rewrite cond function to add a SendToHost node to send 1461 // loop predicate. 1462 TF_RETURN_IF_ERROR( 1463 AddSendLoopPredToLoopCond(fld, cond, n->name(), host_transfer_key)); 1464 n->AddAttr(kXlaTokenInputNodesAttrName, 1465 std::vector<string>{kXlaTokenArgNodeName}); 1466 1467 // Build host side graph for the "While" node. 1468 string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name()); 1469 TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode( 1470 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, 1471 n->name(), host_transfer_key, oc_host_graph_name, fld, 1472 cond_host_func_name, body_host_func_name)); 1473 host_graphs->push_back(oc_host_graph_name); 1474 } 1475 1476 return Status::OK(); 1477 } 1478 1479 } // namespace 1480 1481 Status RewriteOutsideCompilationSubgraphFn::operator()( 1482 const std::vector<OutputTensor>& arg_source_tensors, 1483 std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation, 1484 std::vector<int>* output_permutation, NodeDef* node_def) { 1485 string old_name = node_def->op(); 1486 string new_name = absl::StrCat(xla_cluster_name_, "_", old_name); 1487 node_def->set_op(new_name); 1488 node_def->set_name(new_name); 1489 1490 // Later we will run PruneForReverseReachability(), so make sure all original 1491 // nodes are reachable from sink node and won't be removed. 1492 FixupSourceAndSinkEdges(graph->get()); 1493 1494 // Step 1: create a key placeholder node. 1495 TF_ASSIGN_OR_RETURN( 1496 Node * key_placeholder, 1497 AddHostComputeKeyPlaceholder(xla_cluster_name_, graph->get())); 1498 1499 // Step 2: build RecvAtHost node, and replace all _Arg nodes with it. 1500 std::vector<DataType> recv_at_host_dtypes; 1501 TF_ASSIGN_OR_RETURN( 1502 Node * recv_at_host_node, 1503 ReplaceArgNodesWithRecvAtHostNode(graph->get(), new_name, 1504 &recv_at_host_dtypes, key_placeholder)); 1505 1506 // Step 3: build SendFromHost node, and replace all _Retval nodes with it. 1507 std::vector<DataType> send_from_host_dtypes; 1508 TF_ASSIGN_OR_RETURN( 1509 Node * send_from_host_node, 1510 ReplaceRetNodesWithSendFromHostNode( 1511 graph->get(), new_name, &send_from_host_dtypes, key_placeholder)); 1512 1513 // Step 4: add XLA cluster and outside compilation attr. 1514 for (Node* n : (*graph)->nodes()) { 1515 if (IsKeyPlaceholderNode(*n)) { 1516 continue; 1517 } 1518 1519 n->AddAttr(xla_cluster_attr_name_, xla_cluster_name_); 1520 n->AddAttr(outside_compilation_attr_name_, old_name); 1521 } 1522 1523 // Check whether we have all input shapes for XlaSendFromHost. If we do, we 1524 // will set `shapes` attr for the call node; otherwise we will save the 1525 // shape inference graph and set `shape_inference_graph` for the call node. 1526 absl::optional<std::vector<PartialTensorShape>> shapes = 1527 GetInferredInputShapes(send_from_host_dtypes.size(), send_from_host_node); 1528 for (Node* n : (*graph)->nodes()) { 1529 n->ClearAttr(kXlaInferredShapesAttrName); 1530 } 1531 1532 // Step 5: add control edges for originally XLA <-> outside compilation 1533 // control edges. 1534 for (Node* n : (*graph)->nodes()) { 1535 if (HasNodeAttr(n->def(), kXlaConnectedToXlaComputationAttrName)) { 1536 (*graph)->AddControlEdge(n, send_from_host_node); 1537 n->ClearAttr(kXlaConnectedToXlaComputationAttrName); 1538 } 1539 if (HasNodeAttr(n->def(), kXlaConnectedFromXlaComputationAttrName)) { 1540 (*graph)->AddControlEdge(recv_at_host_node, n); 1541 n->ClearAttr(kXlaConnectedFromXlaComputationAttrName); 1542 } 1543 } 1544 1545 // Step 6: RecvAtHost/SendFromHost/key_placeholder might be dead nodes. Prune 1546 // them if necessary. 1547 // - RecvAtHost should be pruned iff it has no output data/control edges. If 1548 // it has any output edge, it will be reverse reachable from sink node. We 1549 // don't need to do anything special. 1550 // - SendFromHost should be pruned iff it has no input data/control edges. If 1551 // it has input edges other than key_placeholder, we connect it to sink 1552 // node so it won't be pruned. 1553 // - key_placeholder should be pruned iff RecvAtHost/SendFromHost are pruned. 1554 // We don't need to do anything special. 1555 if (send_from_host_node->in_edges().size() > 1) { 1556 (*graph)->AddControlEdge(send_from_host_node, (*graph)->sink_node()); 1557 } 1558 PruneForReverseReachability( 1559 graph->get(), std::unordered_set<const Node*>{(*graph)->sink_node()}); 1560 1561 // Step 7: add necessary attributes to function call node, so we can replace 1562 // it with HostCompute node later. 1563 AddNodeAttr("_outside_compilation_subgraph", old_name, node_def); 1564 if (shapes) { 1565 NameAttrList shape_inference_graph; 1566 AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def); 1567 AddNodeAttr("shapes", *shapes, node_def); 1568 } else { 1569 string shape_inference_func_name = 1570 absl::StrCat("_outside_compilation_shape_inference_", new_name); 1571 NameAttrList shape_inference_graph; 1572 shape_inference_graph.set_name(shape_inference_func_name); 1573 AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def); 1574 AddNodeAttr("shapes", std::vector<TensorShapeProto>{}, node_def); 1575 } 1576 AddNodeAttr("ancestors", std::vector<string>{}, node_def); 1577 AddNodeAttr("Tinputs", recv_at_host_dtypes, node_def); 1578 AddNodeAttr("Toutputs", send_from_host_dtypes, node_def); 1579 AddNodeAttr("key", absl::StrCat("host_compute_channel_", new_name), node_def); 1580 1581 return Status::OK(); 1582 } 1583 1584 Status ExtractOutsideCompilationForFunction( 1585 const string& xla_cluster_attr_name, 1586 const string& outside_compilation_attr_name, const string& xla_cluster_name, 1587 const NameAttrList& func_name_attrs, const string& new_func_name, 1588 const string& host_graph_func_name, 1589 const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr, 1590 FunctionLibraryDefinition* fld, std::vector<string>* shape_inference_graphs, 1591 bool* has_outside_compilation) { 1592 // Convert the function to graph. 1593 const string& func_name = func_name_attrs.name(); 1594 FunctionLibraryRuntime::Handle handle; 1595 TF_RETURN_IF_ERROR( 1596 flr->Instantiate(func_name, AttrSlice(&func_name_attrs.attr()), &handle)); 1597 Status ret_status = Status::OK(); 1598 auto cleanup_handle = gtl::MakeCleanup([&]() { 1599 auto s = flr->ReleaseHandle(handle); 1600 if (!s.ok()) { 1601 ret_status.Update(s); 1602 } 1603 }); 1604 const FunctionBody* fbody = flr->GetFunctionBody(handle); 1605 1606 // Check if we have outside compilation nodes. 1607 *has_outside_compilation = false; 1608 for (Node* n : fbody->graph->nodes()) { 1609 if (HasNodeAttr(n->def(), outside_compilation_attr_name)) { 1610 *has_outside_compilation = true; 1611 break; 1612 } 1613 } 1614 // We cannot early return here, because we might have outside compilation in 1615 // If/While function body. 1616 1617 // Preprocess edges between different outside compilations. They will be 1618 // restored in `ConstructHostGraph()`. 1619 TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations( 1620 fbody->graph, outside_compilation_attr_name)); 1621 if (VLOG_IS_ON(4)) { 1622 DumpGraphToFile( 1623 absl::StrCat("extract_outside_compilation_for_func_before_", func_name), 1624 *fbody->graph, fld); 1625 } 1626 1627 // Encapsulate outside_compilation cluster into function call node. 1628 std::unique_ptr<Graph> graph_out; 1629 RewriteOutsideCompilationSubgraphFn rewrite_fn( 1630 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name); 1631 TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( 1632 outside_compilation_attr_name, "", *fbody->graph, rewrite_fn, 1633 /*reuse_existing_functions=*/true, &graph_out, fld)); 1634 1635 // Replace outside_compilation function nodes with HostCompute ops. 1636 std::vector<Node*> outside_compilation_nodes; 1637 std::vector<string> outside_compilation_host_graphs; 1638 for (Node* n : graph_out->nodes()) { 1639 if (HasNodeAttr(n->def(), "_outside_compilation_subgraph")) { 1640 outside_compilation_nodes.push_back(n); 1641 outside_compilation_host_graphs.push_back(n->name()); 1642 1643 // If we could not infer shapes for XlaSendFromHost inputs statically, we 1644 // will set the "shape_inference_graph" attribute. In that case, copy 1645 // outside compilation subgraph as shape inference graph in `fld`. 1646 NameAttrList shape_inference_graph; 1647 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph", 1648 &shape_inference_graph)); 1649 if (!shape_inference_graph.name().empty()) { 1650 shape_inference_graphs->push_back(shape_inference_graph.name()); 1651 1652 const FunctionDef* xla_fdef = fld->Find(n->name()); 1653 if (!xla_fdef) { 1654 return errors::Internal("Cannot find XLA function ", n->name()); 1655 } 1656 FunctionDef shape_inference_fdef = *xla_fdef; 1657 shape_inference_fdef.mutable_signature()->set_name( 1658 shape_inference_graph.name()); 1659 if (fld->Find(shape_inference_graph.name())) { 1660 TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph.name(), 1661 shape_inference_fdef)); 1662 } else { 1663 TF_RETURN_IF_ERROR(fld->AddFunctionDef(shape_inference_fdef)); 1664 } 1665 } 1666 } 1667 } 1668 for (Node* n : outside_compilation_nodes) { 1669 TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n)); 1670 TF_RETURN_IF_ERROR(ReplaceOrRemoveOutsideCompilationCallNode( 1671 graph_out.get(), n, host_compute_core)); 1672 } 1673 1674 // Handle nodes with associated functions. 1675 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForNodesWithAssociatedFunctions( 1676 graph_out.get(), xla_cluster_attr_name, outside_compilation_attr_name, 1677 xla_cluster_name, host_compute_core, flr, fld, 1678 &outside_compilation_host_graphs, shape_inference_graphs, 1679 has_outside_compilation)); 1680 1681 // Construct host graph. 1682 TF_RETURN_IF_ERROR(ConstructHostGraph( 1683 xla_cluster_name, outside_compilation_attr_name, 1684 outside_compilation_host_graphs, fld, host_graph_func_name)); 1685 1686 // Remove the outside compilation graphs from function library. 1687 for (const string& func : outside_compilation_host_graphs) { 1688 TF_RETURN_IF_ERROR(fld->RemoveFunction(func)); 1689 } 1690 1691 // Replace original function. 1692 FunctionDef updated_fdef; 1693 TF_RETURN_IF_ERROR( 1694 GraphToFunctionDef(*graph_out, new_func_name, &updated_fdef)); 1695 const FunctionDef* original_fdef = fld->Find(func_name); 1696 if (original_fdef) { 1697 for (const auto& attr : original_fdef->attr()) { 1698 (*updated_fdef.mutable_attr())[attr.first] = attr.second; 1699 } 1700 } 1701 if (fld->Find(new_func_name)) { 1702 TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, updated_fdef)); 1703 } else { 1704 TF_RETURN_IF_ERROR(fld->AddFunctionDef(updated_fdef)); 1705 } 1706 if (VLOG_IS_ON(4)) { 1707 DumpGraphToFile( 1708 absl::StrCat("extract_outside_compilation_for_func_after_", func_name), 1709 *graph_out, fld); 1710 } 1711 1712 return ret_status; 1713 } 1714 1715 Status ExtractOutsideCompilation( 1716 const string& xla_cluster_attr_name, 1717 const string& outside_compilation_attr_name, 1718 const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g, 1719 FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld) { 1720 if (VLOG_IS_ON(4)) { 1721 DumpGraphToFile("extract_outside_compilation_before", *g, fld); 1722 } 1723 1724 std::vector<string> shape_inference_graphs; 1725 for (auto& iter : clusters) { 1726 string xla_cluster_name = iter.first; 1727 Node* n = iter.second.node; 1728 auto const& func_name_attrs = iter.second.func_name_attrs; 1729 auto const& host_compute_core = iter.second.host_compute_core; 1730 1731 bool has_outside_compilation; 1732 string host_graph_func_name = absl::StrCat("oc_host_graph_", n->name()); 1733 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( 1734 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, 1735 func_name_attrs, func_name_attrs.name(), host_graph_func_name, 1736 host_compute_core, flr, fld, &shape_inference_graphs, 1737 &has_outside_compilation)); 1738 TF_RETURN_IF_ERROR( 1739 ExpandHostGraphIntoMainGraph(g, fld, host_graph_func_name, n)); 1740 TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name)); 1741 } 1742 1743 for (auto shape_inference_graph_name : shape_inference_graphs) { 1744 TF_RETURN_IF_ERROR( 1745 RewriteShapeInferenceGraph(shape_inference_graph_name, g, fld)); 1746 } 1747 1748 if (VLOG_IS_ON(4)) { 1749 DumpGraphToFile("extract_outside_compilation_after", *g, fld); 1750 } 1751 return Status::OK(); 1752 } 1753 1754 } // namespace tensorflow 1755