/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"

#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {

// Adds a key placeholder node to the graph. The key placeholder node will be
// used as input for XlaRecvAtHost/XlaSendFromHost nodes.
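// The placeholder is a "Placeholder" op with dtype DT_STRING and shape [2],
// as set in the attrs below; the actual key value is only fed in at runtime.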
xla::StatusOr<Node*> AddHostComputeKeyPlaceholder(
    const string& xla_cluster_name, Graph* g) {
  NodeDef key_def;
  NodeDefBuilder builder(absl::StrCat(xla_cluster_name, "_key_placeholder"),
                         "Placeholder");
  builder.Attr("dtype", DT_STRING);
  builder.Attr("shape", PartialTensorShape({2}));
  builder.Attr("_host_compute_call_node", xla_cluster_name);
  Status s = builder.Finalize(&key_def);
  if (!s.ok()) return s;

  Node* n = g->AddNode(key_def, &s);
  if (!s.ok()) return s;
  return n;
}

// Returns true if the node is an XLA computation key placeholder.
bool IsKeyPlaceholderNode(const Node& n) {
  return n.type_string() == "Placeholder" &&
         absl::EndsWith(n.name(), "_key_placeholder");
}

// Returns all nodes in `g` with the given type string.
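// Example (mirroring the call in ReplaceArgNodesWithRecvAtHostNode below):
//   std::vector<Node*> arg_nodes = GatherNodesWithType(*g, "_Arg");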
std::vector<Node*> GatherNodesWithType(const Graph& g, const string& type) {
  std::vector<Node*> result;
  for (Node* n : g.nodes()) {
    if (n->type_string() == type) {
      result.push_back(n);
    }
  }
  return result;
}

// Gets data types from `arg_nodes` and fills them into `recv_at_host_dtypes`.
Status GetArgDataTypes(const std::vector<Node*>& arg_nodes,
                       std::vector<DataType>* recv_at_host_dtypes) {
  recv_at_host_dtypes->resize(arg_nodes.size(), DT_INVALID);
  for (auto* n : arg_nodes) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype));
    (*recv_at_host_dtypes)[index] = dtype;
  }
  for (int i = 0; i < recv_at_host_dtypes->size(); i++) {
    if ((*recv_at_host_dtypes)[i] == DT_INVALID) {
      return errors::Internal("Cannot get datatype for input ", i);
    }
  }
  return Status::OK();
}

// Builds XlaRecvAtHost node.
xla::StatusOr<Node*> BuildRecvAtHostNode(
    Graph* g, const string& oc_cluster_name,
    const std::vector<DataType>& recv_at_host_dtypes, Node* key_placeholder) {
  NodeDefBuilder recv_at_host_builder(
      absl::StrCat("outside_compilation_", oc_cluster_name, "_recv"),
      "_XlaRecvAtHost");
  NodeDef recv_at_host_def;
  recv_at_host_builder.Attr("Toutputs", recv_at_host_dtypes);
  // The correct device_ordinal will be inserted during replication in a
  // subsequent rewrite.
  AttrValue device_ordinal_value;
  device_ordinal_value.set_placeholder("device_ordinal");
  recv_at_host_builder.Attr("device_ordinal", device_ordinal_value);
  recv_at_host_builder.Attr(
      "key", absl::StrCat("host_compute_channel_", oc_cluster_name));
  recv_at_host_builder.Attr(kXlaHasHostTransferAttrName, true);
  recv_at_host_builder.Input(key_placeholder->name(), 0, DT_STRING);
  TF_RETURN_IF_ERROR(recv_at_host_builder.Finalize(&recv_at_host_def));
  Status s;
  Node* recv_at_host_node = g->AddNode(recv_at_host_def, &s);
  TF_RETURN_IF_ERROR(s);
  return recv_at_host_node;
}

// Builds XlaRecvAtHost node, and replaces all _Arg nodes with it.
xla::StatusOr<Node*> ReplaceArgNodesWithRecvAtHostNode(
    Graph* g, const string& oc_cluster_name,
    std::vector<DataType>* recv_at_host_dtypes, Node* key_placeholder) {
  // TODO(b/77601805): use out nodes for source node, instead of traversing all
  // nodes.
  std::vector<Node*> arg_nodes = GatherNodesWithType(*g, "_Arg");
  TF_RETURN_IF_ERROR(GetArgDataTypes(arg_nodes, recv_at_host_dtypes));
  TF_ASSIGN_OR_RETURN(
      Node * recv_at_host_node,
      BuildRecvAtHostNode(g, oc_cluster_name, *recv_at_host_dtypes,
                          key_placeholder));
  for (auto* n : arg_nodes) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    // Record out edges and remove `n` before adding those edges to RecvAtHost.
    // This is to avoid multiple producers.
    std::vector<OutEdgeInfo> out_edge_info;
    for (auto edge : n->out_edges()) {
      out_edge_info.push_back(
          {edge->dst(), edge->src_output(), edge->dst_input()});
    }
    g->RemoveNode(n);
    for (const OutEdgeInfo& edge : out_edge_info) {
      if (edge.dst_input == Graph::kControlSlot) {
        g->AddControlEdge(recv_at_host_node, edge.dst);
      } else {
        g->AddEdge(recv_at_host_node, index, edge.dst, edge.dst_input);
      }
    }

    // Rewrite dst nodes because their input changed.
    for (int i = 0; i < out_edge_info.size(); i++) {
      const OutEdgeInfo edge = out_edge_info[i];
      if (edge.dst_input == Graph::kControlSlot) {
        continue;
      }

      Node* dst = edge.dst;
      NodeDef new_def = dst->def();
      *new_def.mutable_input(edge.dst_input) =
          absl::StrCat(recv_at_host_node->name(), ":", index);
      TF_ASSIGN_OR_RETURN(Node * dst_replace, ReplaceNode(g, dst, new_def));

      // Other edges might have `dst` as dst node as well. Update those edges
      // with `dst_replace`.
      for (int j = i + 1; j < out_edge_info.size(); j++) {
        if (out_edge_info[j].dst == dst) {
          out_edge_info[j].dst = dst_replace;
        }
      }
    }
  }
  g->AddEdge(key_placeholder, 0, recv_at_host_node, 0);
  return recv_at_host_node;
}

// Gets data types from `ret_nodes` and fills them into `send_from_host_dtypes`.
Status GetRetDataTypes(const std::vector<Node*>& ret_nodes,
                       std::vector<DataType>* send_from_host_dtypes) {
  send_from_host_dtypes->resize(ret_nodes.size(), DT_INVALID);
  for (auto* n : ret_nodes) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype));
    (*send_from_host_dtypes)[index] = dtype;
  }
  for (int i = 0; i < send_from_host_dtypes->size(); i++) {
    if ((*send_from_host_dtypes)[i] == DT_INVALID) {
      return errors::Internal("Cannot get datatype for output ", i);
    }
  }
  return Status::OK();
}

// Builds XlaSendFromHost node.
xla::StatusOr<Node*> BuildSendFromHostNode(
    Graph* g, const string& oc_cluster_name,
    const std::vector<Node*>& ret_nodes,
    const std::vector<DataType>& send_from_host_dtypes, Node* key_placeholder) {
  NodeDefBuilder send_from_host_builder(
      absl::StrCat("outside_compilation_", oc_cluster_name, "_send"),
      "_XlaSendFromHost");
  NodeDef send_from_host_def;
  send_from_host_builder.Attr("Tinputs", send_from_host_dtypes);
  // The correct device_ordinal will be inserted during replication in a
  // subsequent rewrite.
  AttrValue device_ordinal_value;
  device_ordinal_value.set_placeholder("device_ordinal");
  send_from_host_builder.Attr("device_ordinal", device_ordinal_value);
  send_from_host_builder.Attr(
      "key", absl::StrCat("host_compute_channel_", oc_cluster_name));
  send_from_host_builder.Attr(kXlaHasHostTransferAttrName, true);
  std::vector<NodeDefBuilder::NodeOut> inputs(send_from_host_dtypes.size());
  for (auto* n : ret_nodes) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    if (index < 0 || index >= send_from_host_dtypes.size()) {
      return errors::Internal("Invalid _Retval index: ", index);
    }
    for (auto edge : n->in_edges()) {
      inputs[index] =
          NodeDefBuilder::NodeOut{edge->src()->name(), edge->src_output(),
                                  edge->src()->output_type(edge->src_output())};
    }
  }
  send_from_host_builder.Input(inputs);
  send_from_host_builder.Input(key_placeholder->name(), 0, DT_STRING);
  TF_RETURN_IF_ERROR(send_from_host_builder.Finalize(&send_from_host_def));
  Status s;
  Node* send_from_host_node = g->AddNode(send_from_host_def, &s);
  TF_RETURN_IF_ERROR(s);
  return send_from_host_node;
}

// Builds XlaSendFromHost node, and replaces all _Retval nodes with it.
xla::StatusOr<Node*> ReplaceRetNodesWithSendFromHostNode(
    Graph* g, const string& oc_cluster_name,
    std::vector<DataType>* send_from_host_dtypes, Node* key_placeholder) {
  // TODO(b/77601805): use in nodes for sink node, instead of traversing all
  // nodes.
  std::vector<Node*> ret_nodes = GatherNodesWithType(*g, "_Retval");
  TF_RETURN_IF_ERROR(GetRetDataTypes(ret_nodes, send_from_host_dtypes));
  TF_ASSIGN_OR_RETURN(
      Node * send_from_host_node,
      BuildSendFromHostNode(g, oc_cluster_name, ret_nodes,
                            *send_from_host_dtypes, key_placeholder));
  for (auto* n : ret_nodes) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    for (auto edge : n->in_edges()) {
      if (edge->src_output() == Graph::kControlSlot) {
        g->AddControlEdge(edge->src(), send_from_host_node);
      } else {
        g->AddEdge(edge->src(), edge->src_output(), send_from_host_node, index);
      }
    }
    g->RemoveNode(n);
  }
  g->AddEdge(key_placeholder, 0, send_from_host_node,
             send_from_host_dtypes->size());
  return send_from_host_node;
}

// Returns input shapes (excluding key placeholder) for `send_from_host_node`
// if they are all fully defined; absl::nullopt otherwise.
absl::optional<std::vector<PartialTensorShape>> GetInferredInputShapes(
    int num_inputs, Node* send_from_host_node) {
  std::vector<PartialTensorShape> results(num_inputs);
  for (int i = 0; i < num_inputs; i++) {
    const Edge* e;
    if (!send_from_host_node->input_edge(i, &e).ok()) {
      return absl::nullopt;
    }

    std::vector<PartialTensorShape> shapes;
    if (!GetNodeAttr(e->src()->attrs(), kXlaInferredShapesAttrName, &shapes)
             .ok()) {
      return absl::nullopt;
    }

    const PartialTensorShape shape = shapes[e->src_output()];
    if (!shape.IsFullyDefined()) {
      return absl::nullopt;
    }

    results[e->dst_input()] = shape;
  }
  return results;
}

// Builds XlaHostCompute NodeDef from the outside compilation call node.
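// The new node is named "outside_compilation_<name>_host_compute", where
// <name> comes from the call node's "_outside_compilation_subgraph" attr, and
// it carries a copy of all attrs of the call node.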
xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
    const Node* call_node, const std::map<string, int>& host_compute_core) {
  string original_oc_name;
  TF_RETURN_IF_ERROR(GetNodeAttr(
      call_node->attrs(), "_outside_compilation_subgraph", &original_oc_name));
  NodeDefBuilder host_compute_builder(
      absl::StrCat("outside_compilation_", original_oc_name, "_host_compute"),
      "XlaHostCompute");

  // Copy all attributes.
  for (auto attr : call_node->attrs()) {
    host_compute_builder.Attr(attr.first, attr.second);
  }

  // Populate tpu_core assignment.
  const auto iter = host_compute_core.find(original_oc_name);
  if (iter != host_compute_core.end()) {
    int core = iter->second;
    host_compute_builder.Attr("tpu_core", core);
  }

  // Set input tokens.
  host_compute_builder.Attr(kXlaTokenInputNodesAttrName,
                            std::vector<string>{kXlaTokenArgNodeName});

  // Populate inputs.
  std::vector<DataType> input_dtypes;
  TF_RETURN_IF_ERROR(GetNodeAttr(call_node->attrs(), "Tinputs", &input_dtypes));
  std::vector<NodeDefBuilder::NodeOut> inputs(input_dtypes.size());
  for (auto e : call_node->in_edges()) {
    if (e->IsControlEdge()) {
      continue;
    }

    if (e->dst_input() < 0 || e->dst_input() >= input_dtypes.size()) {
      return errors::Internal("Invalid dst_input: ", e->dst_input());
    }
    inputs[e->dst_input()] = NodeDefBuilder::NodeOut{
        e->src()->name(), e->src_output(), input_dtypes[e->dst_input()]};
  }
  host_compute_builder.Input(inputs);

  NodeDef new_def;
  TF_RETURN_IF_ERROR(host_compute_builder.Finalize(&new_def));
  return new_def;
}

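// Checks that the outside compilation call node has no int64 inputs or
// outputs, which are not supported yet (b/120809951).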
Status ValidateOutsideCompilationCallNode(Node* call_node) {
  // DT_INT64 as input/output for outside compilation is not supported yet:
  // b/120809951.
  for (const Edge* e : call_node->in_edges()) {
    if (e->IsControlEdge()) {
      continue;
    }
    DataType dtype = e->src()->output_type(e->src_output());
    if (dtype == DT_INT64) {
      return errors::Unimplemented(
          "int64 input for outside compilation is not supported yet: "
          "b/120809951. Please cast output of node ",
          e->src()->DebugString(),
          " to int32 before feeding it into outside compilation.");
    }
  }
  for (const Edge* e : call_node->out_edges()) {
    if (e->IsControlEdge()) {
      continue;
    }
    DataType dtype = e->dst()->input_type(e->dst_input());
    if (dtype == DT_INT64) {
      return errors::Unimplemented(
          "int64 output for outside compilation is not supported yet: "
          "b/120809951. Please cast input of node ",
          e->dst()->DebugString(),
          " to int32 before returning it from outside compilation.");
    }
  }
  return Status::OK();
}

// Replaces an outside compilation function call node with an XlaHostCompute
// node. If the function call node has no input/output edges, we will just
// remove it and not create an XlaHostCompute node.
Status ReplaceOrRemoveOutsideCompilationCallNode(
    Graph* g, Node* call_node, const std::map<string, int>& host_compute_core) {
  // If the function call node has no input/output edges, just remove it.
  bool has_edge = false;
  for (auto e : call_node->in_edges()) {
    if (!e->IsControlEdge() || e->src() != g->source_node()) {
      has_edge = true;
      break;
    }
  }
  for (auto e : call_node->out_edges()) {
    if (!e->IsControlEdge() || e->dst() != g->sink_node()) {
      has_edge = true;
      break;
    }
  }
  if (!has_edge) {
    VLOG(4) << "Did not add HostCompute node for " << call_node->DebugString();
    g->RemoveNode(call_node);
    return Status::OK();
  }

  // Build XlaHostCompute NodeDef.
  TF_ASSIGN_OR_RETURN(NodeDef node_def,
                      BuildXlaHostComputeNodeDef(call_node, host_compute_core));
  TF_ASSIGN_OR_RETURN(Node * host_compute_node,
                      ReplaceNode(g, call_node, node_def));
  VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString();

  return Status::OK();
}

// Resets "device_ordinal" attr to placeholder value for related nodes
// (XlaRecvAtHost nodes; XlaSendFromHost nodes; If/While/FuncCall nodes
// containing XlaRecvAtHost/XlaSendFromHost).
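// For example, a node carrying the temporary attr value device_ordinal = 0
// (set because FunctionDef instantiation does not allow placeholder attr
// values) gets the attr reset to the placeholder "device_ordinal", to be
// substituted with the real ordinal during replication.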
Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) {
  AttrValue device_ordinal_value;
  device_ordinal_value.set_placeholder("device_ordinal");
  for (Node* n : g->nodes()) {
    if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) {
      continue;
    }

    if (n->type_string() == "_XlaRecvAtHost" ||
        n->type_string() == "_XlaSendFromHost") {
      n->ClearAttr("device_ordinal");
      n->AddAttr("device_ordinal", device_ordinal_value);
    } else if (n->type_string() == "If") {
      for (const string& attr_name :
           std::vector<string>{"then_branch", "else_branch"}) {
        NameAttrList branch_func;
        TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func));
        (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value;
        n->ClearAttr(attr_name);
        n->AddAttr(attr_name, branch_func);
      }
    } else if (n->type_string() == "While") {
      for (const string& attr_name : std::vector<string>{"cond", "body"}) {
        NameAttrList branch_func;
        TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func));
        (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value;
        n->ClearAttr(attr_name);
        n->AddAttr(attr_name, branch_func);
      }
    } else if (HasNodeAttr(n->def(), "device_ordinal")) {
      // Function call node containing outside compilation.
      n->ClearAttr("device_ordinal");
      n->AddAttr("device_ordinal", device_ordinal_value);
    } else {
      return errors::Internal("Unknown node marked with ",
                              kXlaHasHostTransferAttrName, ": ",
                              n->DebugString());
    }
  }
  return Status::OK();
}

// For an XLA computation, builds the host side graph given all outside
// compilation graphs inside it. The host side graph contains:
// 1) a "sequencer" node (we will add control edges from XlaRecvAtHost and
//    XlaSendFromHost nodes to this sequencer node, so all outside compilation
//    nodes will be executed *before* this sequencer).
// 2) a "key placeholder" node. Later in ExpandHostGraphIntoMainGraph(), we will
//    replace this node with the compilation result node.
// 3) all outside compilation graphs.
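// Roughly, the result looks like: the key placeholder feeds the copied
// XlaRecvAtHost/XlaSendFromHost nodes, and every host transfer node gets a
// control edge to the sequencer node.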
Status ConstructHostGraph(
    const string& xla_cluster_name, const string& outside_compilation_attr_name,
    const std::vector<string>& outside_compilation_host_graphs,
    FunctionLibraryDefinition* fld, const string& host_graph_func_name) {
  Graph host_graph(fld);

  // Create sequencer node in host graph.
  NodeDefBuilder sequencer_builder(absl::StrCat(xla_cluster_name, "_sequencer"),
                                   "NoOp");
  sequencer_builder.Attr("_xla_host_transfer_sequencer", xla_cluster_name);
  NodeDef sequencer_def;
  TF_RETURN_IF_ERROR(sequencer_builder.Finalize(&sequencer_def));
  Status s;
  Node* sequencer = host_graph.AddNode(sequencer_def, &s);
  TF_RETURN_IF_ERROR(s);

  // Create key placeholder in host graph.
  TF_ASSIGN_OR_RETURN(
      Node * key_placeholder,
      AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));

  // For each outside compilation graph, copy it to the host graph with the
  // following changes:
  // a) Use key_placeholder in the host graph instead of its own.
  // b) Add control edges from host transfer nodes (XlaRecvAtHost,
  //    XlaSendFromHost, If/While nodes containing
  //    XlaRecvAtHost/XlaSendFromHost) to the sequencer node.
  // c) Clear node_def.device(), so the device placer won't get confused.
  for (const string& host_func : outside_compilation_host_graphs) {
    VLOG(4) << "Expanding host graph " << host_func;
    // Temporarily use "0" as "device_ordinal". It will be reset to the
    // placeholder value after we have expanded all host graphs. We cannot use
    // the placeholder value here because FunctionDef instantiation does not
    // allow placeholder values for attributes.
    AttrValue device_ordinal_attr;
    device_ordinal_attr.set_i(0);
    protobuf::Map<string, AttrValue> attrs;
    attrs["device_ordinal"] = device_ordinal_attr;
    FunctionBody* host_fbody = nullptr;
    TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
        *fld->Find(host_func), AttrSlice(&attrs), fld,
        [&](const string& op, const OpDef** sig) {
          return fld->LookUpOpDef(op, sig);
        },
        &host_fbody));
    std::unique_ptr<FunctionBody> host_fbody_deleter(host_fbody);

    // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse
    // reachable from sink node so all nodes will be copied.
    // TODO(b/77601805): consolidate copy graph functions.
    FixupSourceAndSinkEdges(host_fbody->graph);

    std::map<const Node*, Node*> node_map;
    node_map[host_fbody->graph->source_node()] = host_graph.source_node();
    node_map[host_fbody->graph->sink_node()] = host_graph.sink_node();
    Status s;
    ReverseDFS(
        *host_fbody->graph, /*enter=*/nullptr,
        [&](const Node* n) {
          if (!s.ok()) {
            return;
          }

          Node* copy;
          if (node_map.find(n) != node_map.end()) {
            // Already copied this node.
            copy = node_map.at(n);
          } else if (IsKeyPlaceholderNode(*n)) {
            // Change a).
            copy = key_placeholder;
            node_map[n] = copy;
          } else {
            // Copy the node.
            NodeDef copy_def = n->def();
            // Change c).
            copy_def.clear_device();
            copy = host_graph.AddNode(copy_def, &s);
            if (!s.ok()) {
              return;
            }
            node_map[n] = copy;
          }

          // Only handle input edges. Output edges will be added later as
          // its output nodes' input edges.
          for (auto e : n->in_edges()) {
            if (node_map.find(e->src()) == node_map.end()) {
              s = errors::Internal("Cannot find node image for ",
                                   e->src()->DebugString());
              return;
            }
            host_graph.AddEdge(node_map[e->src()], e->src_output(), copy,
                               e->dst_input());
          }

          // Change b).
          if (HasNodeAttr(copy->def(), kXlaHasHostTransferAttrName)) {
            host_graph.AddControlEdge(copy, sequencer);
          }
        },
        NodeComparatorID());

    if (!s.ok()) {
      return s;
    }
  }
  // Reset "device_ordinal" to placeholder value.
  TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(&host_graph));

  // sequencer and key_placeholder might be dead nodes. Prune them if necessary.
  // - sequencer should be pruned iff it has no input control edges from
  //   RecvAtHost/SendFromHost. If it has any input control edge, we connect it
  //   to the sink node so it won't be pruned.
  // - key_placeholder should be pruned iff there's no RecvAtHost/SendFromHost.
  //   We don't need to do anything special in that case.
  if (!sequencer->in_edges().empty()) {
    host_graph.AddControlEdge(sequencer, host_graph.sink_node());
  }
  PruneForReverseReachability(
      &host_graph, std::unordered_set<const Node*>{host_graph.sink_node()});

  // Postprocess edges between different outside compilations.
  TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations(
      &host_graph, outside_compilation_attr_name));

  if (VLOG_IS_ON(4)) {
    DumpGraphToFile(absl::StrCat("extract_outside_compilation_host_graph_for_",
                                 xla_cluster_name),
                    host_graph, fld);
  }

  FunctionDef host_graph_fdef;
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(host_graph, host_graph_func_name, &host_graph_fdef));
  if (fld->Find(host_graph_func_name)) {
    TF_RETURN_IF_ERROR(
        fld->ReplaceFunction(host_graph_func_name, host_graph_fdef));
  } else {
    TF_RETURN_IF_ERROR(fld->AddFunctionDef(host_graph_fdef));
  }

  return Status::OK();
}

// Expands an XLA computation's outside compilation host side graph into the
// main graph, and adds a control edge between the sequencer node and the XLA
// computation node.
Status ExpandHostGraphIntoMainGraph(Graph* main_graph,
                                    FunctionLibraryDefinition* fld,
                                    const string& host_graph_func_name,
                                    Node* xla_computation_node) {
  // Temporarily use "0" as "device_ordinal". It will be rewritten with the
  // correct value in a later pass. We cannot just use placeholder value here
  // because FunctionDef instantiation does not allow placeholder value for
  // attributes.
  AttrValue device_ordinal_attr;
  device_ordinal_attr.set_i(0);
  protobuf::Map<string, AttrValue> attrs;
  attrs["device_ordinal"] = device_ordinal_attr;
  FunctionBody* fbody = nullptr;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *fld->Find(host_graph_func_name), AttrSlice(&attrs), fld,
      [&](const string& op, const OpDef** sig) {
        return fld->LookUpOpDef(op, sig);
      },
      &fbody));
  std::unique_ptr<FunctionBody> fbody_deleter(fbody);
  Graph* host_graph = fbody->graph;

  // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse
  // reachable from sink node so all nodes will be copied.
  // TODO(b/77601805): consolidate copy graph functions.
  FixupSourceAndSinkEdges(host_graph);

  // Copy all nodes.
  std::map<const Node*, Node*> node_map;
  node_map[host_graph->source_node()] = main_graph->source_node();
  node_map[host_graph->sink_node()] = main_graph->sink_node();
  Status s = Status::OK();
  auto copy_node_fn = [&](const Node* n) {
    if (!s.ok()) {
      return;
    }

    Node* copy;
    if (node_map.find(n) != node_map.end()) {
      // Already copied this node.
      copy = node_map.at(n);
    } else {
      // Copy the node.
      NodeDef copy_def = n->def();
      copy = main_graph->AddNode(copy_def, &s);
      if (!s.ok()) {
        return;
      }
      node_map[n] = copy;
    }

    // Only handle input edges. Output edges will be added later as its output
    // nodes' input edges.
    for (auto e : n->in_edges()) {
      if (node_map.find(e->src()) == node_map.end()) {
        s = errors::Internal("Cannot find node image for ",
                             e->src()->DebugString());
        return;
      }
      main_graph->AddEdge(node_map[e->src()], e->src_output(), copy,
                          e->dst_input());
    }

    // Add control edge from sequencer to XLA computation node.
    if (copy->type_string() == "NoOp" &&
        HasNodeAttr(copy->def(), "_xla_host_transfer_sequencer")) {
      main_graph->AddControlEdge(copy, xla_computation_node);
    }
  };
  ReverseDFS(*host_graph, /*enter=*/nullptr, copy_node_fn, NodeComparatorID());
  return s;
}

// Rewrites shape inference graph for outside compilation:
// 1) If XlaSendFromHost also exists in `host_graph`, copy nodes from
//    `host_graph` instead. The shape inference graph might still contain
//    outside-compilation-to-outside-compilation placeholder nodes, which would
//    prevent us from inferring the XlaSendFromHost shape; in `host_graph`,
//    those placeholder nodes have already been removed.
// 2) Remove control edges.
// 3) Prune nodes that are not useful for shape inference.
Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
                                  Graph* host_graph,
                                  FunctionLibraryDefinition* fld) {
  // Use "0" as "device_ordinal". It does not matter for shape inference.
  AttrValue device_ordinal_attr;
  device_ordinal_attr.set_i(0);
  protobuf::Map<string, AttrValue> attrs;
  attrs["device_ordinal"] = device_ordinal_attr;
  FunctionBody* fbody = nullptr;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *fld->Find(shape_inference_graph_name), AttrSlice(&attrs), fld,
      [&](const string& op, const OpDef** sig) {
        return fld->LookUpOpDef(op, sig);
      },
      &fbody));
  std::unique_ptr<FunctionBody> fbody_deleter(fbody);
  Graph* g = fbody->graph;

  // Find SendFromHost node.
  Node* send_from_host = nullptr;
  for (Node* n : g->nodes()) {
    if (n->type_string() == "_XlaSendFromHost") {
      send_from_host = n;
      break;
    }
  }
  if (!send_from_host) {
    return errors::Internal("Shape inference graph ",
                            shape_inference_graph_name,
                            " does not have _XlaSendFromHost node.");
  }

  // See if the SendFromHost node exists in `host_graph`.
  Node* send_from_host_main_graph = nullptr;
  for (Node* n : host_graph->nodes()) {
    if (n->name() == send_from_host->name()) {
      send_from_host_main_graph = n;
      break;
    }
  }
  if (send_from_host_main_graph) {
    // This is a "top-level" outside compilation. Clear the graph, and copy
    // SendFromHost and all its predecessors from `host_graph`.
    std::vector<Node*> nodes;
    for (Node* n : g->op_nodes()) {
      nodes.push_back(n);
    }
    for (Node* n : nodes) {
      g->RemoveNode(n);
    }

    std::map<const Node*, Node*> node_map;
    node_map[host_graph->source_node()] = g->source_node();
    Status s;
    auto copy_node_fn = [&](const Node* n) {
      if (!s.ok()) {
        return;
      }

      if (node_map.find(n) != node_map.end()) {
        return;
      }

      NodeDef copy_def = n->def();
      Node* copy = g->AddNode(copy_def, &s);
      if (!s.ok()) {
        return;
      }
      for (auto e : n->in_edges()) {
        if (node_map.find(e->src()) == node_map.end()) {
          s = errors::Internal("Cannot find node image for ",
                               e->src()->DebugString());
          return;
        }
        g->AddEdge(node_map[e->src()], e->src_output(), copy, e->dst_input());
      }

      node_map[n] = copy;
    };
    // TODO(b/77601805): consolidate copy graph functions.
    ReverseDFSFrom(*host_graph,
                   std::vector<const Node*>{send_from_host_main_graph},
                   /*enter=*/nullptr, copy_node_fn, NodeComparatorID());
    if (!s.ok()) {
      return s;
    }

    send_from_host = node_map[send_from_host_main_graph];
  } else {
    // This is an outside compilation embedded in If/While/gradient/etc. The
    // graph as it stands is enough for shape inference, so leave `g` unchanged.
  }

  // Control edges are not useful for shape inference. Remove them.
  for (auto e : g->edges()) {
    if (e->IsControlEdge()) {
      g->RemoveEdge(e);
    }
  }

  // Nodes that are not reverse reachable from SendFromHost are not useful for
  // shape inference. Prune them.
  PruneForReverseReachability(g,
                              std::unordered_set<const Node*>{send_from_host});

  if (VLOG_IS_ON(4)) {
    DumpGraphToFile(shape_inference_graph_name, *g, fld);
  }

  // Replace original shape inference graph.
  FunctionDef fdef_replace;
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*g, shape_inference_graph_name, &fdef_replace));
  TF_RETURN_IF_ERROR(
      fld->ReplaceFunction(shape_inference_graph_name, fdef_replace));

  return Status::OK();
}

// Builds XlaSendToHost node which sends cond predicate to host.
xla::StatusOr<Node*> BuildSendIfPredNode(const string& name,
                                         const string& host_transfer_key,
                                         Node* pred_node, Graph* g) {
  NodeDefBuilder send_pred_builder(name, "XlaSendToHost");
  send_pred_builder.Attr("Tinput", DT_BOOL);
  send_pred_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0"));
  send_pred_builder.Attr(kXlaTokenInputNodesAttrName,
                         std::vector<string>{kXlaTokenArgNodeName});
  send_pred_builder.Input(pred_node->name(), 0, DT_BOOL);
  NodeDef send_pred_def;
  TF_RETURN_IF_ERROR(send_pred_builder.Finalize(&send_pred_def));
  Status s;
  Node* send_pred_node = g->AddNode(send_pred_def, &s);
  TF_RETURN_IF_ERROR(s);
  g->AddEdge(pred_node, 0, send_pred_node, 0);
  return send_pred_node;
}

// Replaces key placeholder node with an _Arg node.
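// Note: this rewrites the function `func_name` in `fld` in place; the
// rewritten graph is converted back to a FunctionDef and swapped in via
// ReplaceFunction().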
Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name,
                                        const string& func_name,
                                        FunctionLibraryDefinition* fld) {
  // Temporarily use "0" as "device_ordinal". It will be reset to placeholder
  // value after rewriting.
  AttrValue device_ordinal_attr;
  device_ordinal_attr.set_i(0);
  protobuf::Map<string, AttrValue> attrs;
  attrs["device_ordinal"] = device_ordinal_attr;
  FunctionBody* fbody = nullptr;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *fld->Find(func_name), AttrSlice(&attrs), fld,
      [&](const string& op, const OpDef** sig) {
        return fld->LookUpOpDef(op, sig);
      },
      &fbody));
  std::unique_ptr<FunctionBody> fbody_deleter(fbody);
  Graph* g = fbody->graph;

  // Find or create the key placeholder node.
  Node* key_placeholder = nullptr;
  for (Node* n : g->nodes()) {
    if (IsKeyPlaceholderNode(*n)) {
      key_placeholder = n;
      break;
    }
  }
  if (!key_placeholder) {
    TF_ASSIGN_OR_RETURN(key_placeholder,
                        AddHostComputeKeyPlaceholder(xla_cluster_name, g));
  }

  // Build the _Arg node, and replace key placeholder node with it.
  NodeDefBuilder arg_builder("key_arg", FunctionLibraryDefinition::kArgOp);
  arg_builder.Attr("T", DT_STRING);
  arg_builder.Attr("index", 0);
  NodeDef arg_def;
  TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def));
  TF_RETURN_IF_ERROR(ReplaceNode(g, key_placeholder, arg_def).status());

  // Reset "device_ordinal" to placeholder value.
  TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(g));

  FunctionDef replace_fdef;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(*g, func_name, &replace_fdef));
  TF_RETURN_IF_ERROR(fld->ReplaceFunction(func_name, replace_fdef));
  return Status::OK();
}

// Builds host side graph for If node.
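// Roughly, the resulting host graph is: key placeholder -> _XlaRecvAtHost
// (receives the predicate) -> If node, whose then/else branches are the
// rewritten host functions taking the key as an _Arg.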
Status BuildHostGraphForIfNode(const string& xla_cluster_attr_name,
                               const string& outside_compilation_attr_name,
                               const string& xla_cluster_name,
                               const string& if_node_name,
                               const string& host_transfer_key,
                               const string& host_graph_func_name,
                               FunctionLibraryDefinition* fld,
                               const string& then_branch_host_func_name,
                               const string& else_branch_host_func_name) {
  Graph host_graph(fld);
  string outside_compilation_name = absl::StrCat("oc_if_", if_node_name);
  AttrValue device_ordinal_value;
  device_ordinal_value.set_placeholder("device_ordinal");

  // Step 1: add key placeholder node.
  TF_ASSIGN_OR_RETURN(
      Node * key_placeholder,
      AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));

  // Step 2: build XlaRecvAtHost node to recv predicate.
  NodeDefBuilder recv_pred_builder(
      absl::StrCat("recv_oc_if_pred_", if_node_name), "_XlaRecvAtHost");
  recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL});
  recv_pred_builder.Attr("key", host_transfer_key);
  recv_pred_builder.Attr("device_ordinal", device_ordinal_value);
  recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
  recv_pred_builder.Attr(outside_compilation_attr_name,
                         outside_compilation_name);
  recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true);
  recv_pred_builder.Input(key_placeholder->name(), 0, DT_STRING);
  NodeDef recv_pred_def;
  TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def));
  Status s;
  Node* recv_pred_node = host_graph.AddNode(recv_pred_def, &s);
  TF_RETURN_IF_ERROR(s);
  host_graph.AddEdge(key_placeholder, 0, recv_pred_node, 0);

  // Step 3: rewrite `{then, else}_branch_host_func_name`, replace key
  // placeholder with an _Arg node.
  TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
      xla_cluster_name, then_branch_host_func_name, fld));
  TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
      xla_cluster_name, else_branch_host_func_name, fld));

  // Step 4: build If node to choose between `{then, else}_branch_host_graph`.
  NodeDefBuilder if_builder(absl::StrCat("oc_if_", if_node_name), "If");
  if_builder.Attr("Tcond", DT_BOOL);
  if_builder.Attr("Tin", std::vector<DataType>{DT_STRING});
  if_builder.Attr("Tout", std::vector<DataType>{});
  NameAttrList host_then_branch, host_else_branch;
  host_then_branch.set_name(then_branch_host_func_name);
  (*host_then_branch.mutable_attr())["device_ordinal"] = device_ordinal_value;
  host_else_branch.set_name(else_branch_host_func_name);
  (*host_else_branch.mutable_attr())["device_ordinal"] = device_ordinal_value;
  if_builder.Attr("then_branch", host_then_branch);
  if_builder.Attr("else_branch", host_else_branch);
  if_builder.Attr(kXlaHasHostTransferAttrName, true);
  if_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
  if_builder.Attr(outside_compilation_attr_name, outside_compilation_name);
  if_builder.Input(recv_pred_node->name(), 0, DT_BOOL);
  std::vector<NodeDefBuilder::NodeOut> if_inputs{
      {key_placeholder->name(), 0, DT_STRING}};
  if_builder.Input(if_inputs);
  NodeDef if_def;
  TF_RETURN_IF_ERROR(if_builder.Finalize(&if_def));
  Node* if_node = host_graph.AddNode(if_def, &s);
  TF_RETURN_IF_ERROR(s);
  host_graph.AddEdge(recv_pred_node, 0, if_node, 0);
  host_graph.AddEdge(key_placeholder, 0, if_node, 1);

  // Convert `host_graph` to function, and add a "device_ordinal" attr.
  FunctionDef oc_host_graph_fdef;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
                                        &oc_host_graph_fdef));
  if (fld->Find(host_graph_func_name)) {
    TF_RETURN_IF_ERROR(
        fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
  } else {
    TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
  }

  return Status::OK();
}

// Rewrites loop cond to add a node which sends loop cond to host.
Status AddSendLoopPredToLoopCond(FunctionLibraryDefinition* fld,
                                 const NameAttrList& loop_cond_func,
                                 const string& while_node_name,
                                 const string& host_transfer_key) {
  // Instantiate the loop cond function.
  FunctionBody* fbody = nullptr;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *fld->Find(loop_cond_func.name()), AttrSlice(&loop_cond_func.attr()), fld,
      [&](const string& op, const OpDef** sig) {
        return fld->LookUpOpDef(op, sig);
      },
      &fbody));
  std::unique_ptr<FunctionBody> fbody_deleter(fbody);
  Graph* g = fbody->graph;

  // Find the _Retval node and the loop cond node.
  Node* ret_node = nullptr;
  for (Node* n : g->nodes()) {
    if (n->type_string() == "_Retval") {
      if (ret_node) {
        return errors::Internal("Multiple return nodes for loop cond function ",
                                loop_cond_func.name(), ": ",
                                ret_node->DebugString(), " and ",
                                n->DebugString());
      } else {
        ret_node = n;
      }
    }
  }
  if (!ret_node) {
    return errors::Internal("No _Retval node for loop cond function ",
                            loop_cond_func.name());
  }
  Node* loop_cond;
  TF_RETURN_IF_ERROR(ret_node->input_node(0, &loop_cond));

  // Build the XlaSendToHost node.
  NodeDefBuilder send_loop_cond_builder(
      absl::StrCat("send_oc_while_cond_", while_node_name), "XlaSendToHost");
  send_loop_cond_builder.Attr("Tinput", DT_BOOL);
  send_loop_cond_builder.Attr("key",
                              absl::StrCat(host_transfer_key, "_dtoh_0"));
  send_loop_cond_builder.Attr(kXlaTokenInputNodesAttrName,
                              std::vector<string>{kXlaTokenArgNodeName});
  send_loop_cond_builder.Input(loop_cond->name(), 0, DT_BOOL);
  NodeDef send_loop_cond_def;
  TF_RETURN_IF_ERROR(send_loop_cond_builder.Finalize(&send_loop_cond_def));
  Status s;
  Node* send_loop_cond_node = g->AddNode(send_loop_cond_def, &s);
  TF_RETURN_IF_ERROR(s);
  g->AddEdge(loop_cond, 0, send_loop_cond_node, 0);

  // Replace original function.
  FunctionDef replace_fdef;
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*g, loop_cond_func.name(), &replace_fdef));
  TF_RETURN_IF_ERROR(fld->ReplaceFunction(loop_cond_func.name(), replace_fdef));

  return Status::OK();
}

// Rewrites while loop cond function for host.
Status RewriteHostWhileLoopCond(
    const string& cond_host_func_name, const string& while_node_name,
    const string& host_transfer_key, const string& xla_cluster_attr_name,
    const string& xla_cluster_name, const string& outside_compilation_attr_name,
    const string& outside_compilation_name, FunctionLibraryDefinition* fld) {
  // Replace key placeholder node with _Arg node.
  TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
      xla_cluster_name, cond_host_func_name, fld));

  // Instantiate cond function.
  AttrValue device_ordinal_temp_value;
  device_ordinal_temp_value.set_i(0);
  protobuf::Map<string, AttrValue> attrs;
  attrs["device_ordinal"] = device_ordinal_temp_value;
  FunctionBody* cond_fbody = nullptr;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *fld->Find(cond_host_func_name), AttrSlice(&attrs), fld,
      [&](const string& op, const OpDef** sig) {
        return fld->LookUpOpDef(op, sig);
      },
      &cond_fbody));
  std::unique_ptr<FunctionBody> cond_fbody_deleter(cond_fbody);
  Graph* cond_graph = cond_fbody->graph;
  Node* key_arg = nullptr;
  for (Node* n : cond_graph->nodes()) {
    if (n->type_string() == "_Arg") {
      key_arg = n;
    }
  }
  if (!key_arg) {
    return errors::Internal(
        "No _Arg node found for host compute key in function ",
        cond_host_func_name);
  }

  // Add an XlaRecvAtHost node to use as cond function return value.
  // We don't need to set kXlaHasHostTransferAttrName for this node, because
  // it's already added for the "While" node on the host.
  NodeDefBuilder recv_pred_builder(
      absl::StrCat("recv_oc_while_cond_", while_node_name), "_XlaRecvAtHost");
  recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL});
  recv_pred_builder.Attr("key", host_transfer_key);
  AttrValue device_ordinal_value;
  device_ordinal_value.set_placeholder("device_ordinal");
  recv_pred_builder.Attr("device_ordinal", device_ordinal_value);
  recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
  recv_pred_builder.Attr(outside_compilation_attr_name,
                         outside_compilation_name);
  recv_pred_builder.Input(key_arg->name(), 0, DT_STRING);
  NodeDef recv_pred_def;
  TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def));
  Status s;
  Node* recv_pred_node = cond_graph->AddNode(recv_pred_def, &s);
  TF_RETURN_IF_ERROR(s);
  cond_graph->AddEdge(key_arg, 0, recv_pred_node, 0);
  NodeDefBuilder ret_builder(
      absl::StrCat("recv_oc_while_cond_ret_", while_node_name), "_Retval");
  ret_builder.Attr("T", DT_BOOL);
  ret_builder.Attr("index", 0);
  ret_builder.Input(recv_pred_node->name(), 0, DT_BOOL);
  NodeDef ret_def;
  TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
  Node* ret_node = cond_graph->AddNode(ret_def, &s);
  TF_RETURN_IF_ERROR(s);
  cond_graph->AddEdge(recv_pred_node, 0, ret_node, 0);

  // Reset device_ordinal to placeholder value.
  TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(cond_graph));

  // Replace original function.
  FunctionDef cond_replace_fdef;
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*cond_graph, cond_host_func_name, &cond_replace_fdef));
  TF_RETURN_IF_ERROR(
      fld->ReplaceFunction(cond_host_func_name, cond_replace_fdef));

  return Status::OK();
}

// Rewrites while loop body function for host.
Status RewriteHostWhileLoopBody(
    const string& body_host_func_name, const string& while_node_name,
    const string& host_transfer_key, const string& xla_cluster_attr_name,
    const string& xla_cluster_name, const string& outside_compilation_attr_name,
    const string& outside_compilation_name, FunctionLibraryDefinition* fld) {
  // Replace key placeholder node with _Arg node.
  TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
      xla_cluster_name, body_host_func_name, fld));

  // Instantiate body function.
  AttrValue device_ordinal_temp_value;
  device_ordinal_temp_value.set_i(0);
  protobuf::Map<string, AttrValue> attrs;
  attrs["device_ordinal"] = device_ordinal_temp_value;
  FunctionBody* body_fbody = nullptr;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *fld->Find(body_host_func_name), AttrSlice(&attrs), fld,
      [&](const string& op, const OpDef** sig) {
        return fld->LookUpOpDef(op, sig);
      },
      &body_fbody));
  std::unique_ptr<FunctionBody> body_fbody_deleter(body_fbody);
  Graph* body_graph = body_fbody->graph;
  Node* key_arg = nullptr;
  for (Node* n : body_graph->nodes()) {
    if (n->type_string() == "_Arg") {
      key_arg = n;
    }
  }
  if (!key_arg) {
    return errors::Internal(
        "No _Arg node found for host compute key in function ",
        body_host_func_name);
  }

  // Add a _Retval node to loop body.
  NodeDefBuilder ret_builder(
      absl::StrCat("recv_oc_while_body_ret_", while_node_name), "_Retval");
  ret_builder.Attr("T", DT_STRING);
  ret_builder.Attr("index", 0);
  ret_builder.Input(key_arg->name(), 0, DT_STRING);
  NodeDef ret_def;
  TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
  Status s;
  Node* ret_node = body_graph->AddNode(ret_def, &s);
  TF_RETURN_IF_ERROR(s);
  body_graph->AddEdge(key_arg, 0, ret_node, 0);

  // Reset device_ordinal to placeholder value.
  TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(body_graph));

  // Replace original function.
  FunctionDef body_replace_fdef;
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*body_graph, body_host_func_name, &body_replace_fdef));
  TF_RETURN_IF_ERROR(
      fld->ReplaceFunction(body_host_func_name, body_replace_fdef));

  return Status::OK();
}

// Builds host side graph for while node.
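// Roughly, the resulting host graph is: key placeholder -> While node, whose
// cond receives the loop predicate via _XlaRecvAtHost and whose body returns
// the key unchanged (see RewriteHostWhileLoopCond/Body above).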
   1163 Status BuildHostGraphForWhileNode(
   1164     const string& xla_cluster_attr_name,
   1165     const string& outside_compilation_attr_name, const string& xla_cluster_name,
   1166     const string& while_node_name, const string& host_transfer_key,
   1167     const string& host_graph_func_name, FunctionLibraryDefinition* fld,
   1168     const string& cond_host_func_name, const string& body_host_func_name) {
   1169   Graph host_graph(fld);
   1170   string outside_compilation_name = absl::StrCat("oc_while_", while_node_name);
   1171 
   1172   // Step 1: add key placeholder node.
   1173   TF_ASSIGN_OR_RETURN(
   1174       Node * key_placeholder,
   1175       AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
   1176 
   1177   // Step 2: rewrite cond function.
   1178   TF_RETURN_IF_ERROR(RewriteHostWhileLoopCond(
   1179       cond_host_func_name, while_node_name, host_transfer_key,
   1180       xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
   1181       outside_compilation_name, fld));
   1182 
   1183   // Step 3: rewrite body function.
   1184   TF_RETURN_IF_ERROR(RewriteHostWhileLoopBody(
   1185       body_host_func_name, while_node_name, host_transfer_key,
   1186       xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
   1187       outside_compilation_name, fld));
   1188 
   1189   // Step 4: build While node.
   1190   NodeDefBuilder while_builder(absl::StrCat("oc_while_", while_node_name),
   1191                                "While");
   1192   while_builder.Attr("T", std::vector<DataType>{DT_STRING});
   1193   NameAttrList func;
   1194   AttrValue device_ordinal_value;
   1195   device_ordinal_value.set_placeholder("device_ordinal");
   1196   (*func.mutable_attr())["device_ordinal"] = device_ordinal_value;
   1197   func.set_name(cond_host_func_name);
   1198   while_builder.Attr("cond", func);
   1199   func.set_name(body_host_func_name);
   1200   while_builder.Attr("body", func);
   1201   while_builder.Attr(kXlaHasHostTransferAttrName, true);
   1202   while_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
   1203   while_builder.Attr(outside_compilation_attr_name, outside_compilation_name);
   1204   std::vector<NodeDefBuilder::NodeOut> while_inputs{
   1205       {key_placeholder->name(), 0, DT_STRING}};
   1206   while_builder.Input(while_inputs);
   1207   NodeDef while_def;
   1208   TF_RETURN_IF_ERROR(while_builder.Finalize(&while_def));
   1209   Status s;
   1210   Node* while_node = host_graph.AddNode(while_def, &s);
   1211   TF_RETURN_IF_ERROR(s);
   1212   host_graph.AddEdge(key_placeholder, 0, while_node, 0);
   1213 
   1214   // Convert `host_graph` to function.
   1215   FunctionDef oc_host_graph_fdef;
   1216   TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
   1217                                         &oc_host_graph_fdef));
   1218   if (fld->Find(host_graph_func_name)) {
   1219     TF_RETURN_IF_ERROR(
   1220         fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
   1221   } else {
   1222     TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
   1223   }
   1224 
   1225   return Status::OK();
   1226 }
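
// Illustrative sketch (commentary only): for a device-side While node named
// "w" in cluster "cluster0", the host graph built above is roughly
//
//   cluster0_key_placeholder (Placeholder, DT_STRING)
//     |--> oc_while_w (While, cond = rewritten cond host func,
//                             body = rewritten body host func)
//
// The cond/body function attrs carry a "device_ordinal" placeholder, so the
// same host graph can later be instantiated for any device ordinal.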

// Builds the host-side graph for a function call node.
Status BuildHostGraphForFuncCallNode(const string& func_call_node_name,
                                     const string& xla_cluster_name,
                                     const string& func_call_host_func_name,
                                     const string& host_graph_func_name,
                                     FunctionLibraryDefinition* fld) {
  Graph host_graph(fld);
  AttrValue device_ordinal_value;
  device_ordinal_value.set_placeholder("device_ordinal");

  // Step 1: add key placeholder node.
  TF_ASSIGN_OR_RETURN(
      Node * key_placeholder,
      AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));

  // Step 2: rewrite `func_call_host_func_name`, replacing its key placeholder
  // with an _Arg node.
  TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
      xla_cluster_name, func_call_host_func_name, fld));

  // Step 3: build a function call node for `func_call_host_func_name`, with
  // `key_placeholder` as input.
  NodeDefBuilder call_builder(absl::StrCat("oc_call_", func_call_node_name),
                              func_call_host_func_name, fld);
  call_builder.Input(key_placeholder->name(), 0, DT_STRING);
  call_builder.Attr("device_ordinal", device_ordinal_value);
  call_builder.Attr(kXlaHasHostTransferAttrName, true);
  NodeDef call_def;
  TF_RETURN_IF_ERROR(call_builder.Finalize(&call_def));
  Status s;
  Node* call_node = host_graph.AddNode(call_def, &s);
  TF_RETURN_IF_ERROR(s);
  host_graph.AddEdge(key_placeholder, 0, call_node, 0);

  // Convert `host_graph` to a function; the call node built above carries the
  // "device_ordinal" placeholder attr.
  FunctionDef oc_host_graph_fdef;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
                                        &oc_host_graph_fdef));
  if (fld->Find(host_graph_func_name)) {
    TF_RETURN_IF_ERROR(
        fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
  } else {
    TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
  }

  return Status::OK();
}
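
// Illustrative sketch (commentary only): for a device-side call node "f" in
// cluster "cluster0", the host graph built above is roughly
//
//   cluster0_key_placeholder (Placeholder, DT_STRING)
//     |--> oc_call_f (calls the rewritten host function, which now takes the
//                     key as an _Arg instead of a key placeholder)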

Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
    Graph* g, const string& xla_cluster_attr_name,
    const string& outside_compilation_attr_name, const string& xla_cluster_name,
    const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr,
    FunctionLibraryDefinition* fld, std::vector<string>* host_graphs,
    std::vector<string>* shape_inference_graphs,
    bool* has_outside_compilation) {
  std::vector<Node*> if_nodes, while_nodes, func_call_nodes;
  for (Node* n : g->nodes()) {
    if (n->type_string() == "If") {
      if_nodes.push_back(n);
    } else if (n->type_string() == "While") {
      while_nodes.push_back(n);
    } else if (fld->Contains(n->type_string())) {
      func_call_nodes.push_back(n);
    } else if (n->type_string() == FunctionLibraryDefinition::kGradientOp) {
      // Only the gradient of a user-defined function should be treated as a
      // function call node.
      NameAttrList original_func;
      TF_RETURN_IF_ERROR(GetNodeAttr(
          n->def(), FunctionLibraryDefinition::kFuncAttr, &original_func));
      if (fld->Contains(original_func.name())) {
        func_call_nodes.push_back(n);
      }
    }
  }
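
  // Illustrative note (commentary only): a SymbolicGradient node names the
  // function it differentiates in its FunctionLibraryDefinition::kFuncAttr
  // attr, e.g.
  //
  //   node { name: "grad" op: "SymbolicGradient"
  //          attr { key: "f" value { func { name: "my_func" } } } }
  //
  // so it is treated as a call to `my_func` here iff `my_func` lives in `fld`.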

  for (Node* n : func_call_nodes) {
    // Extract outside compilation for the function call.
    bool func_has_outside_compilation = false;
    NameAttrList func;
    func.set_name(n->type_string());
    typedef protobuf::Map<string, AttrValue> AttrMap;
    *func.mutable_attr() = AttrMap(n->attrs().begin(), n->attrs().end());
    string new_func_name = absl::StrCat(n->name(), "_oc");
    string host_func_name = absl::StrCat("oc_func_call_host_", n->name());
    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
        func, new_func_name, host_func_name, host_compute_core, flr, fld,
        shape_inference_graphs, &func_has_outside_compilation));

    // If the function call does not have outside compilation, there is
    // nothing to do.
    if (!func_has_outside_compilation) {
      continue;
    }

    *has_outside_compilation = true;

    // Change `n` to call the new function directly.
    NodeDefBuilder replace_builder(n->name(), new_func_name, fld);
    for (const Edge* e : n->in_edges()) {
      if (e->IsControlEdge()) {
        continue;
      }
      replace_builder.Input(e->src()->name(), e->src_output(),
                            e->src()->output_type(e->src_output()));
    }
    for (const auto& attr : n->attrs()) {
      replace_builder.Attr(attr.first, attr.second);
    }
    NodeDef replace_def;
    TF_RETURN_IF_ERROR(replace_builder.Finalize(&replace_def));
    TF_ASSIGN_OR_RETURN(Node * replace, ReplaceNode(g, n, replace_def));
    replace->AddAttr(kXlaTokenInputNodesAttrName,
                     std::vector<string>{kXlaTokenArgNodeName});

    // Build host side graph for the function call.
    string oc_host_graph_name =
        absl::StrCat("oc_func_host_graph_", replace->name());
    TF_RETURN_IF_ERROR(
        BuildHostGraphForFuncCallNode(replace->name(), xla_cluster_name,
                                      host_func_name, oc_host_graph_name, fld));

    // Record the host graph.
    host_graphs->push_back(oc_host_graph_name);
  }
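
  // Illustrative sketch (commentary only): a call node "f" whose function has
  // outside compilation is rewritten roughly as
  //
  //   before:  f  (calls the original function)
  //   after:   f  (calls "f_oc", the device-side function with outside
  //                compilation replaced by XlaHostCompute ops)
  //
  // plus a recorded host graph "oc_func_host_graph_f" that runs the extracted
  // host-side computation.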

  for (Node* n : if_nodes) {
    // Instantiate "then_branch" and "else_branch".
    NameAttrList then_branch, else_branch;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "then_branch", &then_branch));
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "else_branch", &else_branch));

    // Extract outside compilation for then_branch and else_branch.
    bool then_branch_has_outside_compilation = false;
    bool else_branch_has_outside_compilation = false;
    string then_branch_host_func_name =
               absl::StrCat("oc_then_branch_host_if_", n->name()),
           else_branch_host_func_name =
               absl::StrCat("oc_else_branch_host_if_", n->name());
    string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"),
           else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc");
    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
        then_branch, then_branch_xla_func_name, then_branch_host_func_name,
        host_compute_core, flr, fld, shape_inference_graphs,
        &then_branch_has_outside_compilation));
    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
        else_branch, else_branch_xla_func_name, else_branch_host_func_name,
        host_compute_core, flr, fld, shape_inference_graphs,
        &else_branch_has_outside_compilation));

    // If neither branch has outside compilation, there is nothing to do.
    if (!then_branch_has_outside_compilation &&
        !else_branch_has_outside_compilation) {
      continue;
    }

    *has_outside_compilation = true;

    // Change the If node to call the new functions.
    then_branch.set_name(then_branch_xla_func_name);
    n->ClearAttr("then_branch");
    n->AddAttr("then_branch", then_branch);
    else_branch.set_name(else_branch_xla_func_name);
    n->ClearAttr("else_branch");
    n->AddAttr("else_branch", else_branch);

    string host_transfer_key = absl::StrCat("oc_if_pred_", n->name());

    // XLA computation: add a SendToHost node to send the cond predicate.
    Node* pred_node;
    TF_RETURN_IF_ERROR(n->input_node(0, &pred_node));
    TF_ASSIGN_OR_RETURN(
        Node * send_pred_node,
        BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()),
                            host_transfer_key, pred_node, g));
    n->AddAttr(kXlaTokenInputNodesAttrName,
               std::vector<string>{send_pred_node->name()});

    // Add a control edge from `send_pred_node` to the If node, so that
    // XlaCompiler visits the If node after `send_pred_node` and the token
    // output of `send_pred_node` has already been generated.
    g->AddControlEdge(send_pred_node, n);

    // Build host side graph for the "If" node.
    string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name());
    TF_RETURN_IF_ERROR(BuildHostGraphForIfNode(
        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
        n->name(), host_transfer_key, oc_host_graph_name, fld,
        then_branch_host_func_name, else_branch_host_func_name));
    host_graphs->push_back(oc_host_graph_name);
  }
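
  // Illustrative sketch (commentary only): for an If node "i", the device side
  // now sends the predicate to the host under key "oc_if_pred_i", and the host
  // graph "oc_if_host_graph_i" receives that predicate and dispatches to the
  // host-side then/else branch functions, mirroring the device-side branch.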

  for (Node* n : while_nodes) {
    // Instantiate "cond" and "body".
    NameAttrList cond, body;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "cond", &cond));
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "body", &body));

    // Extract outside compilation for cond and body.
    bool cond_has_outside_compilation = false;
    bool body_has_outside_compilation = false;
    string cond_host_func_name = absl::StrCat("oc_cond_host_while_", n->name()),
           body_host_func_name = absl::StrCat("oc_body_host_while_", n->name());
    string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"),
           body_xla_func_name = absl::StrCat(body.name(), "_oc");
    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
        cond, cond_xla_func_name, cond_host_func_name, host_compute_core, flr,
        fld, shape_inference_graphs, &cond_has_outside_compilation));
    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
        body, body_xla_func_name, body_host_func_name, host_compute_core, flr,
        fld, shape_inference_graphs, &body_has_outside_compilation));

    // If neither cond nor body has outside compilation, there is nothing to
    // do.
    if (!cond_has_outside_compilation && !body_has_outside_compilation) {
      continue;
    }

    *has_outside_compilation = true;

    // Change the While node to call the new functions.
    cond.set_name(cond_xla_func_name);
    n->ClearAttr("cond");
    n->AddAttr("cond", cond);
    body.set_name(body_xla_func_name);
    n->ClearAttr("body");
    n->AddAttr("body", body);

    string host_transfer_key = absl::StrCat("oc_while_pred_", n->name());

    // XLA computation: rewrite the cond function to add a SendToHost node that
    // sends the loop predicate.
    TF_RETURN_IF_ERROR(
        AddSendLoopPredToLoopCond(fld, cond, n->name(), host_transfer_key));
    n->AddAttr(kXlaTokenInputNodesAttrName,
               std::vector<string>{kXlaTokenArgNodeName});

    // Build host side graph for the "While" node.
    string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name());
    TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode(
        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
        n->name(), host_transfer_key, oc_host_graph_name, fld,
        cond_host_func_name, body_host_func_name));
    host_graphs->push_back(oc_host_graph_name);
  }
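
  // Illustrative sketch (commentary only): for a While node "w", the rewritten
  // device-side cond now sends its predicate under key "oc_while_pred_w", and
  // the host graph "oc_while_host_graph_w" runs a host-side While whose cond
  // receives that predicate, so host and device loops stay in lockstep.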

  return Status::OK();
}

}  // namespace

Status RewriteOutsideCompilationSubgraphFn::operator()(
    const std::vector<OutputTensor>& arg_source_tensors,
    std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation,
    std::vector<int>* output_permutation, NodeDef* node_def) {
  string old_name = node_def->op();
  string new_name = absl::StrCat(xla_cluster_name_, "_", old_name);
  node_def->set_op(new_name);
  node_def->set_name(new_name);

  // Later we will run PruneForReverseReachability(), so make sure all original
  // nodes are reachable from the sink node and won't be removed.
  FixupSourceAndSinkEdges(graph->get());

  // Step 1: create a key placeholder node.
  TF_ASSIGN_OR_RETURN(
      Node * key_placeholder,
      AddHostComputeKeyPlaceholder(xla_cluster_name_, graph->get()));

  // Step 2: build a RecvAtHost node, and replace all _Arg nodes with it.
  std::vector<DataType> recv_at_host_dtypes;
  TF_ASSIGN_OR_RETURN(
      Node * recv_at_host_node,
      ReplaceArgNodesWithRecvAtHostNode(graph->get(), new_name,
                                        &recv_at_host_dtypes, key_placeholder));

  // Step 3: build a SendFromHost node, and replace all _Retval nodes with it.
  std::vector<DataType> send_from_host_dtypes;
  TF_ASSIGN_OR_RETURN(
      Node * send_from_host_node,
      ReplaceRetNodesWithSendFromHostNode(
          graph->get(), new_name, &send_from_host_dtypes, key_placeholder));

  // Step 4: add XLA cluster and outside compilation attrs.
  for (Node* n : (*graph)->nodes()) {
    if (IsKeyPlaceholderNode(*n)) {
      continue;
    }

    n->AddAttr(xla_cluster_attr_name_, xla_cluster_name_);
    n->AddAttr(outside_compilation_attr_name_, old_name);
  }

  // Check whether we have all input shapes for XlaSendFromHost. If we do, we
  // will set the `shapes` attr for the call node; otherwise we will save the
  // shape inference graph and set `shape_inference_graph` for the call node.
  absl::optional<std::vector<PartialTensorShape>> shapes =
      GetInferredInputShapes(send_from_host_dtypes.size(), send_from_host_node);
  for (Node* n : (*graph)->nodes()) {
    n->ClearAttr(kXlaInferredShapesAttrName);
  }

  // Step 5: add control edges for what were originally XLA <-> outside
  // compilation control edges.
  for (Node* n : (*graph)->nodes()) {
    if (HasNodeAttr(n->def(), kXlaConnectedToXlaComputationAttrName)) {
      (*graph)->AddControlEdge(n, send_from_host_node);
      n->ClearAttr(kXlaConnectedToXlaComputationAttrName);
    }
    if (HasNodeAttr(n->def(), kXlaConnectedFromXlaComputationAttrName)) {
      (*graph)->AddControlEdge(recv_at_host_node, n);
      n->ClearAttr(kXlaConnectedFromXlaComputationAttrName);
    }
  }

  // Step 6: RecvAtHost/SendFromHost/key_placeholder might be dead nodes. Prune
  // them if necessary.
  // - RecvAtHost should be pruned iff it has no output data/control edges. If
  //   it has any output edge, it will be reverse reachable from the sink node.
  //   We don't need to do anything special.
  // - SendFromHost should be pruned iff it has no input data/control edges. If
  //   it has input edges other than key_placeholder, we connect it to the sink
  //   node so it won't be pruned.
  // - key_placeholder should be pruned iff RecvAtHost/SendFromHost are pruned.
  //   We don't need to do anything special.
  if (send_from_host_node->in_edges().size() > 1) {
    (*graph)->AddControlEdge(send_from_host_node, (*graph)->sink_node());
  }
  PruneForReverseReachability(
      graph->get(), std::unordered_set<const Node*>{(*graph)->sink_node()});

  // Step 7: add necessary attributes to the function call node, so we can
  // replace it with a HostCompute node later.
  AddNodeAttr("_outside_compilation_subgraph", old_name, node_def);
  if (shapes) {
    NameAttrList shape_inference_graph;
    AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def);
    AddNodeAttr("shapes", *shapes, node_def);
  } else {
    string shape_inference_func_name =
        absl::StrCat("_outside_compilation_shape_inference_", new_name);
    NameAttrList shape_inference_graph;
    shape_inference_graph.set_name(shape_inference_func_name);
    AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def);
    AddNodeAttr("shapes", std::vector<TensorShapeProto>{}, node_def);
  }
  AddNodeAttr("ancestors", std::vector<string>{}, node_def);
  AddNodeAttr("Tinputs", recv_at_host_dtypes, node_def);
  AddNodeAttr("Toutputs", send_from_host_dtypes, node_def);
  AddNodeAttr("key", absl::StrCat("host_compute_channel_", new_name), node_def);

  return Status::OK();
}
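
// Illustrative sketch (commentary only): after the rewrite above, the
// host-side subgraph for an outside compilation cluster (e.g. "oc0") in
// cluster "cluster0" looks roughly like
//
//   cluster0_key_placeholder --> _XlaRecvAtHost --> <host ops> -->
//     _XlaSendFromHost
//
// and the call node that stands in for it carries attrs such as
// Tinputs/Toutputs (transfer dtypes), key ("host_compute_channel_cluster0_oc0"),
// and either "shapes" or "shape_inference_graph" -- exactly the attrs consumed
// when the call node is later replaced with an XlaHostCompute node.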

Status ExtractOutsideCompilationForFunction(
    const string& xla_cluster_attr_name,
    const string& outside_compilation_attr_name, const string& xla_cluster_name,
    const NameAttrList& func_name_attrs, const string& new_func_name,
    const string& host_graph_func_name,
    const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr,
    FunctionLibraryDefinition* fld, std::vector<string>* shape_inference_graphs,
    bool* has_outside_compilation) {
  // Convert the function to a graph.
  const string& func_name = func_name_attrs.name();
  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(
      flr->Instantiate(func_name, AttrSlice(&func_name_attrs.attr()), &handle));
  Status ret_status = Status::OK();
  auto cleanup_handle = gtl::MakeCleanup([&]() {
    auto s = flr->ReleaseHandle(handle);
    if (!s.ok()) {
      ret_status.Update(s);
    }
  });
  const FunctionBody* fbody = flr->GetFunctionBody(handle);

  // Check if we have outside compilation nodes.
  *has_outside_compilation = false;
  for (Node* n : fbody->graph->nodes()) {
    if (HasNodeAttr(n->def(), outside_compilation_attr_name)) {
      *has_outside_compilation = true;
      break;
    }
  }
  // We cannot return early here, because there might be outside compilation
  // inside an If/While function body.

  // Preprocess edges between different outside compilations. They will be
  // restored in `ConstructHostGraph()`.
  TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations(
      fbody->graph, outside_compilation_attr_name));
  if (VLOG_IS_ON(4)) {
    DumpGraphToFile(
        absl::StrCat("extract_outside_compilation_for_func_before_", func_name),
        *fbody->graph, fld);
  }

  // Encapsulate each outside_compilation cluster into a function call node.
  std::unique_ptr<Graph> graph_out;
  RewriteOutsideCompilationSubgraphFn rewrite_fn(
      xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name);
  TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
      outside_compilation_attr_name, "", *fbody->graph, rewrite_fn,
      /*reuse_existing_functions=*/true, &graph_out, fld));

  // Replace outside_compilation function nodes with HostCompute ops.
  std::vector<Node*> outside_compilation_nodes;
  std::vector<string> outside_compilation_host_graphs;
  for (Node* n : graph_out->nodes()) {
    if (HasNodeAttr(n->def(), "_outside_compilation_subgraph")) {
      outside_compilation_nodes.push_back(n);
      outside_compilation_host_graphs.push_back(n->name());

      // If we could not infer shapes for XlaSendFromHost inputs statically, we
      // will set the "shape_inference_graph" attribute. In that case, copy the
      // outside compilation subgraph into `fld` as the shape inference graph.
      NameAttrList shape_inference_graph;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph",
                                     &shape_inference_graph));
      if (!shape_inference_graph.name().empty()) {
        shape_inference_graphs->push_back(shape_inference_graph.name());

        const FunctionDef* xla_fdef = fld->Find(n->name());
        if (!xla_fdef) {
          return errors::Internal("Cannot find XLA function ", n->name());
        }
        FunctionDef shape_inference_fdef = *xla_fdef;
        shape_inference_fdef.mutable_signature()->set_name(
            shape_inference_graph.name());
        if (fld->Find(shape_inference_graph.name())) {
          TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph.name(),
                                                  shape_inference_fdef));
        } else {
          TF_RETURN_IF_ERROR(fld->AddFunctionDef(shape_inference_fdef));
        }
      }
    }
  }
  for (Node* n : outside_compilation_nodes) {
    TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n));
    TF_RETURN_IF_ERROR(ReplaceOrRemoveOutsideCompilationCallNode(
        graph_out.get(), n, host_compute_core));
  }

  // Handle nodes with associated functions.
  TF_RETURN_IF_ERROR(ExtractOutsideCompilationForNodesWithAssociatedFunctions(
      graph_out.get(), xla_cluster_attr_name, outside_compilation_attr_name,
      xla_cluster_name, host_compute_core, flr, fld,
      &outside_compilation_host_graphs, shape_inference_graphs,
      has_outside_compilation));

  // Construct host graph.
  TF_RETURN_IF_ERROR(ConstructHostGraph(
      xla_cluster_name, outside_compilation_attr_name,
      outside_compilation_host_graphs, fld, host_graph_func_name));

  // Remove the outside compilation host graphs from the function library.
  for (const string& func : outside_compilation_host_graphs) {
    TF_RETURN_IF_ERROR(fld->RemoveFunction(func));
  }

  // Replace the original function, preserving its attrs.
  FunctionDef updated_fdef;
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*graph_out, new_func_name, &updated_fdef));
  const FunctionDef* original_fdef = fld->Find(func_name);
  if (original_fdef) {
    for (const auto& attr : original_fdef->attr()) {
      (*updated_fdef.mutable_attr())[attr.first] = attr.second;
    }
  }
  if (fld->Find(new_func_name)) {
    TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, updated_fdef));
  } else {
    TF_RETURN_IF_ERROR(fld->AddFunctionDef(updated_fdef));
  }
  if (VLOG_IS_ON(4)) {
    DumpGraphToFile(
        absl::StrCat("extract_outside_compilation_for_func_after_", func_name),
        *graph_out, fld);
  }

  return ret_status;
}
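
// Illustrative sketch (commentary only): after
// ExtractOutsideCompilationForFunction() runs on a function, the function
// library holds
//   - `new_func_name`: the device-side function, with each outside
//     compilation cluster replaced by an XlaHostCompute op, and
//   - `host_graph_func_name`: the host-side function, which receives inputs
//     via _XlaRecvAtHost and returns results via _XlaSendFromHost,
// plus, if needed, "_outside_compilation_shape_inference_*" functions used to
// resolve XlaSendFromHost input shapes at compile time.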

Status ExtractOutsideCompilation(
    const string& xla_cluster_attr_name,
    const string& outside_compilation_attr_name,
    const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
    FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld) {
  if (VLOG_IS_ON(4)) {
    DumpGraphToFile("extract_outside_compilation_before", *g, fld);
  }

  std::vector<string> shape_inference_graphs;
  for (const auto& iter : clusters) {
    string xla_cluster_name = iter.first;
    Node* n = iter.second.node;
    auto const& func_name_attrs = iter.second.func_name_attrs;
    auto const& host_compute_core = iter.second.host_compute_core;

    bool has_outside_compilation;
    string host_graph_func_name = absl::StrCat("oc_host_graph_", n->name());
    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
        func_name_attrs, func_name_attrs.name(), host_graph_func_name,
        host_compute_core, flr, fld, &shape_inference_graphs,
        &has_outside_compilation));
    TF_RETURN_IF_ERROR(
        ExpandHostGraphIntoMainGraph(g, fld, host_graph_func_name, n));
    TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
  }

  for (const auto& shape_inference_graph_name : shape_inference_graphs) {
    TF_RETURN_IF_ERROR(
        RewriteShapeInferenceGraph(shape_inference_graph_name, g, fld));
  }

  if (VLOG_IS_ON(4)) {
    DumpGraphToFile("extract_outside_compilation_after", *g, fld);
  }
  return Status::OK();
}
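
// Minimal usage sketch (commentary only; the attr names and the cluster map
// below are illustrative assumptions, not fixed by this file):
//
//   std::unordered_map<string, XlaClusterInfo> clusters = ...;  // one entry
//                                                               // per cluster
//   TF_RETURN_IF_ERROR(ExtractOutsideCompilation(
//       /*xla_cluster_attr_name=*/"_xla_cluster",
//       /*outside_compilation_attr_name=*/"_xla_outside_compilation",
//       clusters, graph, flr, fld));
//
// After it returns, each cluster's host graph has been expanded into `graph`
// and the temporary host graph functions have been removed from `fld`.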

}  // namespace tensorflow