      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
     17 
     18 #include <functional>
     19 #include <memory>
     20 #include <numeric>
     21 #include <string>
     22 #include <unordered_map>
     23 #include <vector>
     24 
     25 #include "tensorflow/compiler/jit/graph_to_functiondef.h"
     26 #include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h"
     27 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
     28 #include "tensorflow/compiler/tf2xla/const_analysis.h"
     29 #include "tensorflow/compiler/tf2xla/dump_graph.h"
     30 #include "tensorflow/compiler/xla/status_macros.h"
     31 #include "tensorflow/core/common_runtime/function.h"
     32 #include "tensorflow/core/common_runtime/optimization_registry.h"
     33 #include "tensorflow/core/common_runtime/shape_refiner.h"
     34 #include "tensorflow/core/framework/function.h"
     35 #include "tensorflow/core/framework/graph_def_util.h"
     36 #include "tensorflow/core/framework/node_def_builder.h"
     37 #include "tensorflow/core/framework/node_def_util.h"
     38 #include "tensorflow/core/graph/algorithm.h"
     39 #include "tensorflow/core/graph/graph.h"
     40 #include "tensorflow/core/graph/graph_def_builder.h"
     41 #include "tensorflow/core/graph/tensor_id.h"
     42 #include "tensorflow/core/lib/gtl/flatset.h"
     43 #include "tensorflow/core/lib/gtl/map_util.h"
     44 #include "tensorflow/core/lib/hash/hash.h"
     45 #include "tensorflow/core/lib/strings/str_util.h"
     46 #include "tensorflow/core/lib/strings/strcat.h"
     47 #include "tensorflow/core/public/session_options.h"
     48 #include "tensorflow/core/public/version.h"
     49 #include "tensorflow/core/util/device_name_utils.h"
     50 
     51 namespace tensorflow {
     52 
     53 const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel";
     54 const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs";
     55 const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";
     56 
     57 namespace {
     58 
     59 bool AreAllParentsConst(const Node& n,
     60                         const gtl::FlatSet<const Node*>& runtime_const_nodes) {
     61   if (n.type_string() == "GuaranteeConst" || n.type_string() == "Const") {
      62     // If the current node is itself a GuaranteeConst or a Const, there is
      63     // no need to look at its incoming edges.
     64     return true;
     65   }
     66 
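           // A node qualifies as a guaranteed constant only if it has at least one
           // incoming data edge and every data-edge parent is already known to be
           // constant; nodes with only control inputs are not marked.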
     67   bool all_parents_const = true;
     68   bool atleast_one_non_control_edge = false;
     69   for (const Edge* in : n.in_edges()) {
     70     atleast_one_non_control_edge =
     71         atleast_one_non_control_edge || !in->IsControlEdge();
     72     if (!in->IsControlEdge() && runtime_const_nodes.count(in->src()) == 0) {
     73       all_parents_const = false;
     74       break;
     75     }
     76   }
     77   return all_parents_const && atleast_one_non_control_edge;
     78 }
     79 
     80 void MarkGuaranteedConstants(
     81     const Graph& graph,
     82     const std::vector<std::pair<const Node*, Node*>>& src_arg_pairs) {
     83   gtl::FlatSet<const Node*> guaranteed_const_nodes;
     84   std::vector<const Node*> srcs;
     85   srcs.reserve(src_arg_pairs.size());
     86   for (const auto& src_arg : src_arg_pairs) {
     87     srcs.push_back(src_arg.first);
     88   }
     89   ReverseDFSFrom(graph, srcs, /*enter=*/nullptr,
     90                  /*leave=*/[&guaranteed_const_nodes](const Node* n) {
     91                    // TODO(vinuraja): Doesn't work in the presence of loops.
     92                    if (AreAllParentsConst(*n, guaranteed_const_nodes)) {
     93                      guaranteed_const_nodes.insert(n);
     94                    }
     95                  });
     96 
     97   for (auto& src_arg : src_arg_pairs) {
     98     if (guaranteed_const_nodes.count(src_arg.first) != 0) {
     99       VLOG(1) << "Guaranteed const found: " << src_arg.first->DebugString();
    100       src_arg.second->AddAttr("_is_guaranteed_constant", true);
    101     }
    102   }
    103 }
    104 
    105 // A node/slot pair.
    106 // TODO(phawkins): is there a common definition of this?
    107 struct NodeSlot {
    108   NodeSlot() : node(nullptr), slot(-1), dtype(DT_INVALID) {}
    109   NodeSlot(const Node* node, int slot)
    110       : node(node), slot(slot), dtype(DT_INVALID) {}
    111   NodeSlot(const Node* node, int slot, DataType dtype)
    112       : node(node), slot(slot), dtype(dtype) {}
    113 
    114   const Node* node;
    115   int slot;
    116 
    117   // Optional: used to record the destination type of a source NodeSlot in case
    118   // the source output is a Ref type that is cast to a Tensor at the
    119   // destination.
    120   DataType dtype;
    121 
    122   bool operator==(const NodeSlot& other) const {
    123     return node == other.node && slot == other.slot && dtype == other.dtype;
    124   }
    125 
    126   // Leave dtype out of the hash since there are never two NodeSlots with the
    127   // same node and slot and different dtypes.
    128   struct Hasher {
    129     uint64 operator()(NodeSlot const& s) const {
    130       return Hash64Combine(std::hash<const Node*>()(s.node),
    131                            std::hash<int>()(s.slot));
    132     }
    133   };
    134 
    135   struct PairHasher {
    136     uint64 operator()(std::pair<NodeSlot, NodeSlot> const& s) const {
    137       return Hash64Combine(Hasher()(s.first), Hasher()(s.second));
    138     }
    139   };
    140 };
    141 
    142 // TODO(phawkins) add a canonical copy of these operator names and refactor
    143 // everything to use it.
    144 static const char* const kArgOp = "_Arg";
    145 static const char* const kRetValOp = "_Retval";
    146 static const char* const kHostComputeOp = "_XlaHostCompute";
    147 static const char* const kSendFromHostOp = "_XlaSendFromHost";
    148 static const char* const kRecvAtHostOp = "_XlaRecvAtHost";
    149 
    150 class Encapsulator {
    151  public:
    152   Encapsulator(string group_attribute, string outside_compilation_attribute,
    153                Graph const* graph_in)
    154       : group_attribute_(std::move(group_attribute)),
    155         outside_compilation_attribute_(
    156             std::move(outside_compilation_attribute)),
    157         graph_in_(graph_in) {}
    158 
     159   // Find the subgraphs marked with 'group_attribute', and build one new
     160   // subgraph for each distinct value of 'group_attribute'.
    161   Status SplitIntoSubgraphs();
    162 
     163   // Build a FunctionDef for each subgraph, and add it to 'library'. The values of
    164   // the 'group_attribute' annotations become the function names.
    165   // If 'reuse_existing_functions' is set, use an existing function with the
    166   // same name, if any.
    167   // If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
    168   // function conversion.
    169   Status BuildFunctionDefs(const RewriteSubgraphFn& rewrite_subgraph_fn,
    170                            bool reuse_existing_functions,
    171                            FunctionLibraryDefinition* library);
    172 
    173   // Write a copy of the input graph to 'graph_out', where the subgraphs are
    174   // replaced with calls to the new functions.
    175   Status BuildOutputGraph(bool parallel_checking, Graph* graph_out,
    176                           FunctionLibraryDefinition* library);
    177 
    178  private:
     179   // A subgraph of the input graph, all marked with a common 'group_attribute'
     180   // value. A subgraph may contain multiple 'outside_compilation' clusters.
    181   //
    182   // In the following simple example, A, B, ..., E are nodes in the original
    183   // graph. The group attributes and outside_compilation attributes g and oc are
    184   // each shown as either 0 or empty.
    185   //
    186   //  A  -->  B  -->  C  -->  D  -->  E
    187   //  g:      g:0     g:0     g:0     g:
    188   //  oc:     oc:     oc:0    oc:     oc:
    189   //
     190   // The example is rewritten to two graphs: one on the host and one to be
     191   // compiled. The host graph is as follows. RAH is a RecvAtHost node receiving
     192   // input from the compiled cluster, and SFH is a SendFromHost node sending
     193   // results back to the compiled cluster. Dotted edges are control edges. A
    194   // 'sequencing' node S is inserted, and both RAH and SFH are connected via S
    195   // to E (and in general all nodes that depend on nodes in the compiled
    196   // cluster) to ensure that they are not pruned.
    197   //
    198   //  A  -->  Call  -->  E
    199   //                     ^
    200   //                     .
    201   //           ........> S
    202   //       ....          ^
    203   //     ..             .
    204   //  RAH -->  C  --> SFH
    205   //
    206   // The compiled cluster is as follows. HC is a HostCompute node which is the
    207   // source of a channel to the RAH node above and the destination of a channel
    208   // from the SFH node above.
    209   //
    210   //  Arg  --> B  --> HC  --> D --> Retval
    211   //
    212   // The channels HC/RAH and SFH/HC each transmit multiple tensors, so there is
    213   // at most one RAH and SFH in each outside_compilation cluster. This design is
    214   // preferred over adding separate Arg/Retval nodes for each transmitted value
    215   // because it allows optimizations to the host code that would like to limit
    216   // communication between host and device and, e.g., raise only one interrupt
    217   // per channel rather than one per transmitted value.
    218   //
     219   // In general the shapes of the outputs from the HC node cannot be determined
     220   // until the shapes of its inputs are known at compile time, since, e.g.,
     221   // above, the shapes of C's outputs aren't known until the shapes of its
     222   // inputs are known. If the shapes of the HC's outputs can be determined
     223   // during the rewrite, they are stored in the node's 'shapes' attr. Otherwise
     224   // a minimal graph is stored in the shape_inference_graph attr. This graph can
     225   // be used when compiling the HC Op to determine the shapes of the SFH inputs
     226   // given the shapes of any ancestor RAH outputs. If it can be determined that
     227   // the shapes of the SFH inputs will not be inferable even once the shapes of
     228   // the RAH outputs are known, an error is returned by the rewriter.
    229   class Subgraph {
    230    public:
     231     // Copies 'node' into the subgraph's graph, creating that graph first (with
     232     // the same op registry and versions as graph_in) if it doesn't already exist.
    233     Node* MakeNodeImage(const Graph* graph_in, Node* node);
    234 
    235     // Returns the graph the subgraph is being built in.
    236     Graph* GetGraph() const;
    237 
    238     // Builds a FunctionDef, and adds it to 'library'. The value of the
     239     // 'group_attribute' annotation becomes the function name.  If
    240     // 'reuse_existing_functions' is set, use an existing function with the same
    241     // name, if any.  If 'rewrite_subgraph_fn' is set, it is applied to the
    242     // subgraph before function conversion.
    243     Status BuildFunctionDef(const string& name_in,
    244                             const RewriteSubgraphFn& rewrite_subgraph_fn,
    245                             bool reuse_existing_functions,
    246                             FunctionLibraryDefinition* library);
    247 
    248     // Adds the function call node to graph_out.
    249     Status AddFunctionCallNode(
    250         const std::unordered_map<const Node*, Node*>& node_images,
    251         bool parallel_checking, Graph* graph_out);
    252 
    253     // Adds _RecvAtHost and _SendFromHost nodes, where needed, to graph_out.
    254     Status AddOutsideCompilationHostIONodes(
    255         const string& subgraph_name,
    256         const std::unordered_map<const Node*, Node*>& node_images,
    257         Graph* graph_out);
    258 
     259     // Appends the names of all the outside_compilation subgraphs in this
     260     // Subgraph to 'names'.
    261     void GetOutsideCompilationSubgraphNames(std::vector<string>* names) const;
    262 
    263     // Returns the Node that inputs to the function should be wired up to.
    264     Node* GetCallNodeForInputs() const;
    265 
    266     // Returns the Node that outputs to the function should be wired up to.
    267     Node* GetCallNodeForOutputs() const;
    268 
    269     // Returns the index of the arg that the dst of edge should connect to.
    270     int GetArgIndexForEdge(const Edge* edge) const;
    271 
    272     // Returns the index of the result that the src of edge should connect to.
    273     int GetResultIndexForEdge(const Edge* edge) const;
    274 
    275     // Returns the RecvAtHost node for an outside_compilation subgraph.
    276     Node* GetRecvAtHostNode(
    277         const string& outside_compilation_subgraph_name) const;
    278 
    279     // Returns the output slot for the RecvAtHost node that corresponds to the
    280     // source of edge in an outside_compilation subgraph.
    281     int GetRecvAtHostSlot(const string& outside_compilation_subgraph_name,
    282                           const Edge* edge) const;
    283 
    284     // Returns the SendFromHost node for an outside_compilation subgraph.
    285     Node* GetSendFromHostNode(
    286         const string& outside_compilation_subgraph_name) const;
    287 
    288     // Returns the input slot for the SendFromHost node that corresponds to the
    289     // destination of edge in an outside_compilation subgraph.
    290     int GetSendFromHostSlot(const string& outside_compilation_subgraph_name,
    291                             const Edge* edge) const;
    292 
     293     // Creates an _Arg node for the src node of edge, and adds its index to
    294     // args_by_src_, if none exists yet. Also adds its index to args_by_dst_,
    295     // and adds the edge within the subgraph from the _Arg node to the image of
    296     // the dst node.
    297     Status RecordArg(const Edge* edge,
    298                      const std::unordered_map<const Node*, Node*>& node_images,
    299                      std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);
    300 
     301     // Creates a _Retval node for the src node of edge, and adds it to results_,
    302     // if none exists yet. If a new _Retval node is created, also adds the edge
    303     // within the subgraph from the src to the _Retval node.
    304     Status RecordResult(
    305         const Edge* edge,
    306         const std::unordered_map<const Node*, Node*>& node_images);
    307 
    308     // Creates an outside_compilation subgraph for outside_compilation_id if
    309     // none exists yet. Creates an entry for the src node of edge in the list of
    310     // inputs for the outside_compilation subgraph, if none exists yet.
    311     void RecordOutsideCompilationInputOrControl(
    312         const string& outside_compilation_id, const Edge* edge);
    313 
    314     // Creates an outside_compilation subgraph for outside_compilation_id if
    315     // none exists yet. Creates an entry for the src node of edge in the list of
    316     // outputs by src for the outside_compilation subgraph, if none exists
    317     // yet. Creates an entry for the dst node of edge in the list of outputs by
    318     // dst for the outside_compilation subgraph.
    319     void RecordOutsideCompilationOutputOrControl(
    320         const string& outside_compilation_id, const Edge* edge);
    321 
    322     // Adds the HostCompute nodes for each outside_compilation subgraph.
    323     Status AddHostComputes(
    324         const string& subgraph_name,
    325         const std::unordered_map<const Node*, Node*>& node_images);
    326 
    327     // Creates the sequencer node if it doesn't exist, adding it to graph_out.
    328     Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out);
    329 
    330     // If there is a sequencer node, adds a control edge from the sequencer to
    331     // all the downstream nodes of call_node_outputs.
    332     void ConnectSequencerToOutputs(Graph* graph_out);
    333 
    334     Status AddShapeInferenceInfo(
    335         const string& outside_compilation_subgraph_name,
    336         const std::vector<TensorShapeProto>& shapes, GraphDef* inference_graph);
    337 
    338     Status ReplaceFunctionDef(FunctionLibraryDefinition* library);
    339 
    340    private:
    341     struct OutsideCompilationSubgraph {
    342       // Map from source (producer node/slot) tensors in the original graph to
    343       // input index (slot number in the HostCompute/RecvAtHost nodes that will
    344       // be created) for the outside_compilation subgraph.
    345       std::unordered_map<NodeSlot, int, NodeSlot::Hasher> inputs;
    346 
    347       // Set of nodes in the original graph that are the source of control edges
    348       // that cross from the containing compiled subgraph into the
    349       // outside_compilation subgraph. These are recorded by
    350       // RecordOutsideCompilationInputOrControl while walking all the subgraph
    351       // edges, and lifted control edges within the subgraph are added by
    352       // AddSendsToOutsideCompilation once the _HostCompute node has been
     353       // AddHostComputes once the _HostCompute node has been
    354       // destination is added by CopyEdgeToOutputGraph.
    355       std::unordered_set<const Node*> control_inputs;
    356 
    357       // Maps from source (producer node/slot) and destination (consumer
    358       // node/slot) tensors in the original graph to output index (slot number
    359       // in the SendFromHost/HostCompute nodes that will be created) for the
    360       // outside_compilation subgraph.
    361       std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_src;
    362       std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_dst;
    363 
    364       // Set of nodes in the original graph that are the destination of control
    365       // edges that cross from the outside_compilation subgraph into the
    366       // containing compiled subgraph. These are recorded by
    367       // RecordOutsideCompilationOutputOrControl while walking all the subgraph
    368       // edges, and lifted control edges within the subgraph are added by
     369       // AddHostComputes once the _HostCompute node has been
     370       // created. The matching control edge from the source to _SendFromHost
     371       // is added by CopyEdgeToOutputGraph.
    372       std::unordered_set<const Node*> control_outputs;
    373 
    374       // Name of the _HostCompute node in the subgraph.
    375       string host_compute_name;
    376 
    377       // _RecvAtHost node in the output graph. Not owned.
    378       Node* recv_at_host = nullptr;
    379 
    380       // _SendFromHost node in the output graph. Not owned.
    381       Node* send_from_host = nullptr;
    382     };
    383 
    384     // Builds a ParallelCheck op that compares the output of the original
    385     // subgraph with the encapsulated subgraph.
    386     Status BuildParallelCheckOp(
    387         const std::unordered_map<const Node*, Node*>& node_images,
    388         Graph* graph_out);
    389 
    390     // Builds a _RecvAtHost node producing all the inputs of an
    391     // outside_compilation subgraph and stores it in oc_subgraph.recv_at_host.
    392     Status AddRecvAtHostNode(const string& subgraph_name,
    393                              const string& oc_subgraph_name,
    394                              OutsideCompilationSubgraph* oc_subgraph,
    395                              Graph* graph_out);
    396 
    397     // Builds a _SendFromHost node consuming all the outputs of an
    398     // outside_compilation subgraph and stores it in oc_subgraph.send_from_host.
    399     Status AddSendFromHostNode(
    400         const std::unordered_map<const Node*, Node*>& node_images,
    401         const string& subgraph_name, const string& oc_subgraph_name,
    402         OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out);
    403 
    404     // The subgraph extracted from the input graph, suitable for being turned
    405     // into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are
    406     // returned by _Retval nodes.
    407     std::unique_ptr<Graph> graph_;
    408 
    409     // Which device are these nodes on? Used to assign a device to the call
    410     // node.
    411     string device_;
    412 
    413     // NodeDef for the function call node.
    414     NodeDef call_node_def_;
    415 
    416     // Function call node(s) in the output graph. Not owned.
    417     // If parallel_checking is enabled, 'call_node_inputs' is the function call
    418     // node to which inputs should be fed, and 'call_node_outputs' is the
    419     // parallel check op from which outputs should be read. If parallel checking
    420     // is disabled, both point to the function call node.
    421     Node* call_node_inputs_;
    422     Node* call_node_outputs_;
    423 
    424     // Maps from source (producer node/slot) and destination
    425     // (consumer node/slot) tensors in the input graph to _Arg numbers in
    426     // the subgraph. The source map is one-to-one, whereas the dest map may be
    427     // many-to-one.
    428     std::unordered_map<NodeSlot, int, NodeSlot::Hasher> args_by_src_;
    429     std::unordered_map<NodeSlot, int, NodeSlot::Hasher> args_by_dst_;
    430 
    431     // The _Arg nodes in the subgraph, in order by argument number.
    432     std::vector<Node*> args_;
    433 
    434     // Map from source tensor in the input graph to result #.
    435     std::unordered_map<NodeSlot, int, NodeSlot::Hasher> results_;
    436 
    437     // The outside_compilation clusters in this subgraph.
    438     std::unordered_map<string, OutsideCompilationSubgraph>
    439         outside_compilation_subgraphs_;
    440 
    441     // NoOp node in the output graph that is sequenced after the call node and
    442     // used to prevent host-side outside_compilation sends and recvs from being
    443     // pruned.
    444     Node* sequencer_ = nullptr;
    445   };
    446 
    447   // Returns the key attribute and outside_compilation attribute associated
     448   // with a node in 'attr' and 'outside_compilation_attr', respectively. Sets
     449   // either result to the empty string if the respective attribute is not
     450   // found. Returns an error status if there is an outside_compilation attribute
     451   // but no key attribute.
    452   Status GetFunctionNameAttr(Node const* node, string* attr,
    453                              string* outside_compilation_attr) const;
    454 
    455   // Copies edges local to a subgraph. Adds _Arg and _Retval nodes to
    456   // subgraphs for data edges that cross subgraph boundaries.
    457   Status CopySubgraphEdges(
    458       const std::unordered_map<const Node*, Node*>& node_images,
    459       std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);
    460 
    461   // Copies all marked nodes to a subgraph. Does nothing for unmarked nodes,
    462   // or nodes marked outside_compilation.
    463   Status CopySubgraphNodes(std::unordered_map<const Node*, Node*>* node_images);
    464 
    465   // Copies all nodes that aren't in a compiled subgraph to the output graph.
    466   Status CopyNodesToOutputGraph(
    467       bool parallel_checking, Graph* graph_out,
    468       std::unordered_map<const Node*, Node*>* node_images);
    469 
    470   // Adds function call nodes for each compiled subgraph.
    471   Status AddFunctionCallNodes(
    472       const std::unordered_map<const Node*, Node*>& node_images,
    473       bool parallel_checking, Graph* graph_out);
    474 
    475   // Adds _RecvAtHost and _SendFromHost nodes, where needed, for all
    476   // outside_compilation subgraphs.
    477   Status AddOutsideCompilationHostIONodes(
    478       const std::unordered_map<const Node*, Node*>& node_images,
    479       Graph* graph_out);
    480 
    481   // Finds the image of an edge source in the output graph. If the edge crosses
    482   // a subgraph boundary it is the output of a call node, otherwise it is a node
    483   // in the output graph.
    484   Status FindOutputImageOfEdgeSrc(
    485       const string& src_func_id, const string& src_outside_compilation_id,
    486       const string& dst_func_id, const string& dst_outside_compilation_id,
    487       const std::unordered_map<const Node*, Node*>& node_images,
    488       const Node* original_src_node, Node** src_image);
    489 
    490   // Finds an edge source slot in the output graph. If the edge crosses a
    491   // subgraph boundary it is a slot on the output of a call node or a
    492   // _RecvAtHost node, otherwise it is a slot on a node in the output graph.
    493   int FindOutputSlotOfEdgeSrc(const string& src_func_id,
    494                               const string& src_outside_compilation_id,
    495                               const string& dst_func_id,
    496                               const string& dst_outside_compilation_id,
    497                               const Edge* edge);
    498 
    499   // Finds the image of an edge destination in the output graph. If the edge
    500   // crosses a subgraph boundary it is the input of a call node or a
    501   // _SendFromHost node, otherwise it is a node in the output graph.
    502   Status FindOutputImageOfEdgeDst(
    503       const string& src_func_id, const string& src_outside_compilation_id,
    504       const string& dst_func_id, const string& dst_outside_compilation_id,
    505       const std::unordered_map<const Node*, Node*>& node_images,
    506       const Node* original_dst_node, Node** dst_image);
    507 
    508   // Finds an edge destination slot in the output graph. If the edge crosses a
    509   // subgraph boundary it is a slot on the input of a call node or a
    510   // _SendFromHost node, otherwise it is a slot on a node in the output graph.
    511   int FindOutputSlotOfEdgeDst(const string& src_func_id,
    512                               const string& src_outside_compilation_id,
    513                               const string& dst_func_id,
    514                               const string& dst_outside_compilation_id,
    515                               const Edge* edge);
    516 
    517   // Copies a single edge to the output graph. The edge is either entirely
    518   // within the output graph, or crosses into or out of a compiled subgraph.
    519   Status CopyEdgeToOutputGraph(
    520       const Edge* edge, const string& src_func_id,
    521       const string& src_outside_compilation_id, const string& dst_func_id,
    522       const string& dst_outside_compilation_id,
    523       const std::unordered_map<const Node*, Node*>& node_images,
    524       bool parallel_checking, Graph* graph_out,
    525       std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>*
    526           edges_added);
    527 
    528   // Adds all edges to the output graph.
    529   Status AddEdgesToOutputGraph(
    530       const std::unordered_map<const Node*, Node*>& node_images,
    531       bool parallel_checking, Graph* graph_out);
    532 
    533   // Constructs a minimal shape inference graph that can be used to determine
    534   // the shape of send_node at the time that the subgraph is compiled.
    535   // recv_at_host_nodes contains the names of all the recv_at_host nodes that
    536   // send_node might depend on. These recv_at_host nodes have shapes that are
    537   // not known during the rewrite pass, but will be known at compile time.
    538   //
    539   // If the shapes of all the inputs to send_node can be determined during the
    540   // rewrite pass, on exit graphdef_out is empty and the shapes are returned in
    541   // static_shape_out. Otherwise graphdef_out contains a graph that can be used
    542   // for shape inference at compile time, where all the source nodes of the
    543   // graph are either constants with known shapes, or nodes named in
    544   // recv_at_host_nodes.
    545   //
    546   // A non-OK status is returned if neither of the above conditions can be
    547   // satisfied, e.g., because send_node depends on a node that doesn't have a
    548   // registered shape inference function.
    549   Status DoStaticShapeInferenceForOutsideCompilationSend(
    550       const Graph& graph_in, const ShapeRefiner& shape_refiner,
    551       const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
    552       FunctionLibraryDefinition* library,
    553       std::vector<TensorShapeProto>* static_shape_out,
    554       std::unique_ptr<GraphDef>* graphdef_out);
    555 
    556   // Makes a copy of graph containing only nodes that are ancestors of at least
     557   // one node in sink_nodes and stores it in pruned_graph. On exit
     558   // node_images contains a mapping from nodes in graph to nodes in
    559   // pruned_graph. All functions in the copied graph are inlined.
    560   Status MakePrunedGraphCopyAndInline(
    561       const Graph& graph, const std::vector<Node*>& sink_nodes,
    562       std::unique_ptr<Graph>* pruned_graph,
    563       std::unordered_map<const Node*, Node*>* node_images,
    564       FunctionLibraryDefinition* library);
    565 
    566   // Makes a copy of graph containing only nodes that are ancestors of a
     567   // send_from_host node in an outside_compilation subgraph, and stores it in
     568   // pruned_graph. Also performs shape inference on the pruned graph, using
    569   // shape_refiner. On exit node_images contains a mapping from nodes in graph
    570   // to nodes in pruned_graph.
    571   Status MakeGraphForOutsideCompilationSends(
    572       const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
    573       ShapeRefiner* shape_refiner,
    574       std::unordered_map<const Node*, Node*>* node_images,
    575       FunctionLibraryDefinition* library);
    576 
    577   // Performs static shape inference, as far as possible, for the send_from_host
    578   // nodes in each outside_compilation subgraph. Where it is not possible to
    579   // determine the shape statically, stores a serialized GraphDef in the
    580   // HostCompute 'shape_inference_graph' attr, to be used at compile time for
    581   // final inference. If the shapes are known statically they are stored in the
    582   // HostCompute 'shapes' attr.
    583   Status GetShapeInfoForOutsideCompilationSends(
    584       Graph* graph_out, FunctionLibraryDefinition* library);
    585 
    586   const string group_attribute_;
    587   const string outside_compilation_attribute_;
    588   const Graph* graph_in_;
    589 
    590   std::unordered_map<string, Subgraph> subgraphs_;
    591 
    592   TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator);
    593 };
    594 
    595 Node* Encapsulator::Subgraph::GetCallNodeForInputs() const {
    596   return call_node_inputs_;
    597 }
    598 
    599 Node* Encapsulator::Subgraph::GetCallNodeForOutputs() const {
    600   return call_node_outputs_;
    601 }
    602 
    603 int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const {
    604   return args_by_dst_.at(NodeSlot(edge->dst(), edge->dst_input()));
    605 }
    606 
    607 int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const {
    608   return results_.at(NodeSlot(edge->src(), edge->src_output()));
    609 }
    610 
    611 Node* Encapsulator::Subgraph::GetRecvAtHostNode(
    612     const string& outside_compilation_subgraph_name) const {
    613   return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name)
    614       .recv_at_host;
    615 }
    616 
    617 int Encapsulator::Subgraph::GetRecvAtHostSlot(
    618     const string& outside_compilation_subgraph_name, const Edge* edge) const {
    619   return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name)
    620       .inputs.at(NodeSlot(edge->src(), edge->src_output()));
    621 }
    622 
    623 Node* Encapsulator::Subgraph::GetSendFromHostNode(
    624     const string& outside_compilation_subgraph_name) const {
    625   return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name)
    626       .send_from_host;
    627 }
    628 
    629 int Encapsulator::Subgraph::GetSendFromHostSlot(
    630     const string& outside_compilation_subgraph_name, const Edge* edge) const {
    631   return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name)
    632       .outputs_by_dst.at(NodeSlot(edge->dst(), edge->dst_input()));
    633 }
    634 
    635 Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) {
    636   if (!graph_) {
    637     graph_.reset(new Graph(graph_in->op_registry()));
    638     graph_->set_versions(graph_in->versions());
    639   }
    640 
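           // Remember the device of the first node copied into the subgraph,
           // preferring the assigned device over the requested one; it is later used
           // for the function call node and the host-side send/recv nodes.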
    641   if (device_.empty()) {
    642     device_ = node->assigned_device_name().empty()
    643                   ? node->requested_device()
    644                   : node->assigned_device_name();
    645   }
    646 
    647   return graph_->CopyNode(node);
    648 }
    649 
    650 Graph* Encapsulator::Subgraph::GetGraph() const { return graph_.get(); }
    651 
    652 Status Encapsulator::Subgraph::RecordArg(
    653     const Edge* edge, const std::unordered_map<const Node*, Node*>& node_images,
    654     std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
    655   Node* src_node = edge->src();
    656   int src_slot = edge->src_output();
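           // Look up the argument index for this source tensor, creating a new entry
           // (and a new _Arg node below) if it has not been seen before.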
    657   std::unordered_map<NodeSlot, int, NodeSlot::Hasher>::iterator iter;
    658   bool inserted;
    659   std::tie(iter, inserted) =
    660       args_by_src_.emplace(NodeSlot(src_node, src_slot), args_by_src_.size());
    661   int arg_index = iter->second;
    662   if (inserted) {
    663     NodeDef arg_def;
    664     NodeDefBuilder builder(
    665         strings::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp);
    666     DataType dtype = edge->dst()->input_type(edge->dst_input());
    667     builder.Attr("T", dtype);
    668     builder.Attr("index", arg_index);
    669     Status s = builder.Finalize(&arg_def);
    670     if (!s.ok()) return s;
    671 
    672     Node* arg = graph_->AddNode(arg_def, &s);
    673     if (!s.ok()) return s;
    674 
    675     src_arg_pairs->push_back({src_node, arg});
    676     args_.push_back(arg);
    677   }
    678   Node* dst_node = edge->dst();
    679   Node* dst_image = node_images.at(dst_node);
    680   int dst_slot = edge->dst_input();
    681   args_by_dst_[NodeSlot(dst_node, dst_slot)] = arg_index;
    682   graph_->AddEdge(args_[arg_index], 0, dst_image, dst_slot);
    683   return Status::OK();
    684 }
    685 
    686 Status Encapsulator::Subgraph::RecordResult(
    687     const Edge* edge,
    688     const std::unordered_map<const Node*, Node*>& node_images) {
    689   Node* src_node = edge->src();
    690   Node* src_image = node_images.at(src_node);
    691   int src_slot = edge->src_output();
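           // Look up the result index for this source tensor, creating a new entry
           // (and a new _Retval node below) if it has not been seen before.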
    692   std::unordered_map<NodeSlot, int, NodeSlot::Hasher>::iterator iter;
    693   bool inserted;
    694   std::tie(iter, inserted) =
    695       results_.emplace(NodeSlot(src_node, src_slot), results_.size());
    696   int ret_index = iter->second;
    697   if (inserted) {
    698     NodeDef ret_def;
    699     NodeDefBuilder builder(
    700         strings::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp);
    701     DataType dtype = src_node->output_type(src_slot);
    702     builder.Attr("T", dtype);
    703     builder.Attr("index", ret_index);
    704     builder.Input(src_image->name(), src_slot, dtype);
    705     Status s = builder.Finalize(&ret_def);
    706     if (!s.ok()) return s;
    707     Node* ret = graph_->AddNode(ret_def, &s);
    708     if (!s.ok()) return s;
    709 
    710     graph_->AddEdge(src_image, src_slot, ret, 0);
    711   }
    712   return Status::OK();
    713 }
    714 
    715 void Encapsulator::Subgraph::RecordOutsideCompilationInputOrControl(
    716     const string& outside_compilation_id, const Edge* edge) {
    717   auto iter = outside_compilation_subgraphs_
    718                   .emplace(outside_compilation_id, OutsideCompilationSubgraph())
    719                   .first;
    720   OutsideCompilationSubgraph& outside_subgraph = iter->second;
    721   if (edge->IsControlEdge()) {
    722     outside_subgraph.control_inputs.insert(edge->src());
    723   } else {
    724     int input_index = outside_subgraph.inputs.size();
    725     outside_subgraph.inputs.emplace(NodeSlot(edge->src(), edge->src_output()),
    726                                     input_index);
    727   }
    728 }
    729 
    730 void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl(
    731     const string& outside_compilation_id, const Edge* edge) {
    732   auto subgraph_iter =
    733       outside_compilation_subgraphs_
    734           .emplace(outside_compilation_id, OutsideCompilationSubgraph())
    735           .first;
    736   OutsideCompilationSubgraph& outside_subgraph = subgraph_iter->second;
    737   if (edge->IsControlEdge()) {
    738     outside_subgraph.control_outputs.insert(edge->dst());
    739   } else {
    740     DataType dtype = edge->dst()->input_type(edge->dst_input());
    741     auto output_iter =
    742         outside_subgraph.outputs_by_src
    743             .emplace(NodeSlot(edge->src(), edge->src_output(), dtype),
    744                      outside_subgraph.outputs_by_src.size())
    745             .first;
    746     int output_index = output_iter->second;
    747     outside_subgraph.outputs_by_dst[NodeSlot(edge->dst(), edge->dst_input())] =
    748         output_index;
    749   }
    750 }
    751 
    752 Status Encapsulator::Subgraph::AddHostComputes(
    753     const string& subgraph_name,
    754     const std::unordered_map<const Node*, Node*>& node_images) {
    755   for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) {
    756     const string& oc_subgraph_name = oc_subgraph_iter.first;
    757     OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second;
    758     if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty() ||
    759         !oc_subgraph.outputs_by_src.empty() ||
    760         !oc_subgraph.control_outputs.empty()) {
    761       // Build a _HostCompute node.
    762       std::vector<NodeDefBuilder::NodeOut> inputs(oc_subgraph.inputs.size());
    763       std::vector<DataType> input_dtypes(oc_subgraph.inputs.size(), DT_INVALID);
    764       std::vector<DataType> output_dtypes(oc_subgraph.outputs_by_src.size(),
    765                                           DT_INVALID);
    766 
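               // Fill in the _HostCompute inputs and input dtypes in the slot order
               // assigned by RecordOutsideCompilationInputOrControl.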
    767       for (const auto& input_src : oc_subgraph.inputs) {
    768         const Node* src_node = input_src.first.node;
    769         Node* src_image = node_images.at(src_node);
    770         int src_slot = input_src.first.slot;
    771         int input_index = input_src.second;
    772 
    773         DataType dtype = src_node->output_type(src_slot);
    774         inputs[input_index].Reset(src_image->name(), src_slot, dtype);
    775         input_dtypes[input_index] = dtype;
    776       }
    777 
    778       for (const auto& output : oc_subgraph.outputs_by_src) {
    779         DataType dtype = output.first.dtype;
    780         int output_index = output.second;
    781         output_dtypes[output_index] = dtype;
    782       }
    783 
    784       NodeDef host_compute_def;
    785       NodeDefBuilder builder(strings::StrCat("outside_compilation_",
    786                                              oc_subgraph_name, "_host_compute"),
    787                              kHostComputeOp);
    788       builder.Input(inputs);
    789       builder.Attr("Tinputs", input_dtypes);
    790       builder.Attr("Toutputs", output_dtypes);
    791       builder.Attr("key",
    792                    strings::StrCat("host_compute_channel_", subgraph_name, "_",
    793                                    oc_subgraph_name));
    794       Status s = builder.Finalize(&host_compute_def);
    795       if (!s.ok()) return s;
    796 
    797       Node* host_compute = graph_->AddNode(host_compute_def, &s);
    798       if (!s.ok()) return s;
    799       oc_subgraph.host_compute_name = host_compute->name();
    800 
    801       // Connect the _HostCompute node to its producers in the subgraph.
    802       for (auto& input_src : oc_subgraph.inputs) {
    803         const Node* src_node = input_src.first.node;
    804         Node* src_image = node_images.at(src_node);
    805         int src_slot = input_src.first.slot;
    806         int input_index = input_src.second;
    807         graph_->AddEdge(src_image, src_slot, host_compute, input_index);
    808       }
    809 
    810       // Connect the _HostCompute node to its control edge producers in the
    811       // subgraph.
    812       for (const auto& src_node : oc_subgraph.control_inputs) {
    813         Node* src_image = node_images.at(src_node);
    814         graph_->AddControlEdge(src_image, host_compute);
    815       }
    816 
    817       // Connect the consumers in the subgraph to the _HostCompute node.
    818       for (const auto& output : oc_subgraph.outputs_by_dst) {
    819         const Node* dst_node = output.first.node;
    820         Node* dst_image = node_images.at(dst_node);
    821         int dst_slot = output.first.slot;
    822         int output_index = output.second;
    823 
    824         graph_->AddEdge(host_compute, output_index, dst_image, dst_slot);
    825       }
    826 
    827       // Connect the control edge consumers in the subgraph to the _HostCompute
    828       // node.
    829       for (const auto& dst_node : oc_subgraph.control_outputs) {
    830         Node* dst_image = node_images.at(dst_node);
    831         graph_->AddControlEdge(host_compute, dst_image);
    832       }
    833     }
    834   }
    835 
    836   return Status::OK();
    837 }
    838 
    839 Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name,
    840                                                   Graph* graph_out) {
    841   if (sequencer_ == nullptr) {
    842     NodeDef seq_def;
    843     NodeDefBuilder builder(strings::StrCat(subgraph_name, "_sequencer"),
    844                            "NoOp");
    845     Status s = builder.Finalize(&seq_def);
    846     if (!s.ok()) return s;
    847 
    848     sequencer_ = graph_out->AddNode(seq_def, &s);
    849     if (!s.ok()) return s;
    850     sequencer_->set_assigned_device_name(device_);
    851   }
    852   return Status::OK();
    853 }
    854 
    855 void Encapsulator::Subgraph::ConnectSequencerToOutputs(Graph* graph_out) {
    856   if (sequencer_ != nullptr) {
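             // Collect the consumers into a set first so that each downstream node
             // receives exactly one control edge from the sequencer, even if it
             // consumes several of the call node's outputs.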
    857     std::unordered_set<Node*> output_dependencies;
    858     for (Node* node : call_node_outputs_->out_nodes()) {
    859       output_dependencies.insert(node);
    860     }
    861     for (Node* node : output_dependencies) {
    862       graph_out->AddControlEdge(sequencer_, node);
    863     }
    864   }
    865 }
    866 
    867 Status Encapsulator::Subgraph::BuildFunctionDef(
    868     const string& name_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
    869     bool reuse_existing_functions, FunctionLibraryDefinition* library) {
    870   // name_in is copied here because name may be modified below if
     871   // rewrite_subgraph_fn is set.
    872   string name = name_in;
    873   call_node_def_.set_op(name);
    874   call_node_def_.set_name(name);
    875   call_node_def_.set_device(device_);
    876 
    877   if (rewrite_subgraph_fn) {
    878     // Initialize the input and output permutations to the identity.
    879     std::vector<int> input_permutation(args_by_src_.size());
    880     std::iota(input_permutation.begin(), input_permutation.end(), 0);
    881     std::vector<int> output_permutation(results_.size());
    882     std::iota(output_permutation.begin(), output_permutation.end(), 0);
    883 
    884     TF_RETURN_IF_ERROR(rewrite_subgraph_fn(
    885         &graph_, &input_permutation, &output_permutation, &call_node_def_));
    886 
    887     // Apply the input/output permutations to the 'args_by_...' and 'results_'
    888     // mappings, so when we build edges in BuildOutputGraph() we
    889     // connect them to the right input/output positions.
    890     if (input_permutation.size() != args_by_src_.size()) {
    891       return errors::InvalidArgument("Input permutation has incorrect size.");
    892     }
    893     if (output_permutation.size() != results_.size()) {
    894       return errors::InvalidArgument("Output permutation has incorrect size.");
    895     }
    896     for (auto& arg : args_by_src_) {
    897       arg.second = input_permutation[arg.second];
    898     }
    899     for (auto& arg : args_by_dst_) {
    900       arg.second = input_permutation[arg.second];
    901     }
    902     for (auto& result : results_) {
    903       result.second = output_permutation[result.second];
    904     }
    905 
    906     name = call_node_def_.op();
    907   }
    908 
    909   FunctionDef fdef;
    910   TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef));
    911 
    912   if (VLOG_IS_ON(1)) {
    913     VLOG(2) << "Build function def " << name;
    914     dump_graph::DumpGraphToFile(
    915         strings::StrCat("encapsulate_fdef_graph_", name), *graph_, library);
    916     dump_graph::DumpFunctionDefToFile(
    917         strings::StrCat("encapsulate_fdef_", name), fdef);
    918   }
    919 
    920   if (!reuse_existing_functions || library->Find(name) == nullptr) {
    921     TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
    922   }
    923   return Status::OK();
    924 }
    925 
    926 Status Encapsulator::Subgraph::AddShapeInferenceInfo(
    927     const string& outside_compilation_subgraph_name,
    928     const std::vector<TensorShapeProto>& shapes, GraphDef* inference_graph) {
    929   OutsideCompilationSubgraph& oc_subgraph =
    930       outside_compilation_subgraphs_.at(outside_compilation_subgraph_name);
    931 
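           // Locate the _HostCompute node for this outside_compilation cluster by
           // name; it was created earlier by AddHostComputes.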
    932   Node* host_compute = nullptr;
    933   for (Node* n : graph_->nodes()) {
    934     if (n->name() == oc_subgraph.host_compute_name) {
    935       host_compute = n;
    936       break;
    937     }
    938   }
    939   if (host_compute == nullptr) {
    940     return errors::InvalidArgument(
    941         "After rewriting subgraph ", outside_compilation_subgraph_name,
    942         " there is no HostCompute Op for outside compilation subgraph ",
    943         oc_subgraph.host_compute_name);
    944   }
    945 
    946   if (inference_graph == nullptr) {
    947     host_compute->AddAttr("shape_inference_graph", "");
    948     host_compute->AddAttr("shapes", shapes);
    949   } else {
    950     string serialized_graph;
    951     if (!inference_graph->SerializeToString(&serialized_graph)) {
    952       return errors::Internal(
    953           "Failed to serialize graph for outside compilation subgraph ",
    954           oc_subgraph.host_compute_name);
    955     }
    956     host_compute->AddAttr("shape_inference_graph", serialized_graph);
    957     host_compute->AddAttr("shapes", std::vector<TensorShapeProto>());
    958   }
    959   return Status::OK();
    960 }
    961 
    962 Status Encapsulator::Subgraph::ReplaceFunctionDef(
    963     FunctionLibraryDefinition* library) {
    964   const string& name = call_node_def_.name();
    965 
    966   FunctionDef fdef;
    967   TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef));
    968 
    969   if (VLOG_IS_ON(1)) {
    970     VLOG(2) << "Replace function def " << name;
    971     dump_graph::DumpGraphToFile(
    972         strings::StrCat("replace_encapsulate_fdef_graph_", name), *graph_,
    973         library);
    974     dump_graph::DumpFunctionDefToFile(
    975         strings::StrCat("replace_encapsulate_fdef_", name), fdef);
    976   }
    977 
    978   TF_RETURN_IF_ERROR(library->RemoveFunction(name));
    979   TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
    980   return Status::OK();
    981 }
    982 
    983 Status Encapsulator::Subgraph::BuildParallelCheckOp(
    984     const std::unordered_map<const Node*, Node*>& node_images,
    985     Graph* graph_out) {
    986   // Build an index mapping output positions to node/slot pairs in the
    987   // original graph.
    988   std::vector<NodeSlot> results_by_num(results_.size());
    989   for (const auto& entry : results_) {
    990     results_by_num[entry.second] = entry.first;
    991   }
    992 
    993   // Build a parallel check NodeDef.
    994   int num_results = results_by_num.size();
    995   std::vector<DataType> result_dtypes(num_results);
    996   std::vector<NodeDefBuilder::NodeOut> expected_outputs(num_results);
    997   std::vector<NodeDefBuilder::NodeOut> actual_outputs(num_results);
    998   for (int i = 0; i < num_results; ++i) {
    999     const NodeSlot& node_slot = results_by_num[i];
   1000     result_dtypes[i] = node_slot.node->output_type(node_slot.slot);
   1001     expected_outputs[i] =
   1002         NodeDefBuilder::NodeOut(node_images.at(node_slot.node)->name(),
   1003                                 node_slot.slot, result_dtypes[i]);
   1004     actual_outputs[i] =
   1005         NodeDefBuilder::NodeOut(call_node_def_.name(), i, result_dtypes[i]);
   1006   }
   1007   // Assign the parallel check op to a CPU on the same task as the cluster it is
   1008   // checking.
   1009   string device, dummy;
   1010   if (!DeviceNameUtils::SplitDeviceName(
   1011           call_node_inputs_->assigned_device_name(), &device, &dummy)) {
   1012     return errors::InvalidArgument("Could not parse device name");
   1013   }
   1014   strings::StrAppend(&device, "/cpu:0");
   1015 
   1016   NodeDef check_def;
   1017   TF_RETURN_IF_ERROR(
   1018       NodeDefBuilder(graph_out->NewName(strings::StrCat(call_node_def_.name(),
   1019                                                         "_parallel_check")),
   1020                      "ParallelCheck")
   1021           .Device(device)
   1022           .Attr("T", result_dtypes)
   1023           .Input(expected_outputs)
   1024           .Input(actual_outputs)
   1025           .Finalize(&check_def));
   1026 
   1027   Status s;
   1028   Node* check_op = graph_out->AddNode(check_def, &s);
   1029   if (!s.ok()) return s;
   1030   check_op->set_assigned_device_name(device);
   1031 
   1032   // TODO(phawkins): it seems redundant to call AddEdge as well as
   1033   // pass Inputs to the NodeDefBuilder, but I have been unable to find a
   1034   // way to avoid it.
   1035   for (int i = 0; i < num_results; ++i) {
   1036     const NodeSlot& node_slot = results_by_num[i];
   1037     graph_out->AddEdge(node_images.at(node_slot.node), node_slot.slot, check_op,
   1038                        i);
   1039     graph_out->AddEdge(call_node_inputs_, i, check_op, num_results + i);
   1040   }
   1041 
   1042   call_node_outputs_ = check_op;
   1043   return Status::OK();
   1044 }
   1045 
   1046 Status Encapsulator::Subgraph::AddFunctionCallNode(
   1047     const std::unordered_map<const Node*, Node*>& node_images,
   1048     bool parallel_checking, Graph* graph_out) {
   1049   Status s;
   1050   call_node_inputs_ = graph_out->AddNode(call_node_def_, &s);
   1051   if (!s.ok()) return s;
   1052 
    1053   // Copy the assigned device over to the call node.
   1054   call_node_inputs_->set_assigned_device_name(device_);
   1055   call_node_outputs_ = call_node_inputs_;
   1056 
   1057   if (parallel_checking) {
   1058     TF_RETURN_IF_ERROR(BuildParallelCheckOp(node_images, graph_out));
   1059   }
   1060   return Status::OK();
   1061 }
   1062 
   1063 Status Encapsulator::Subgraph::AddRecvAtHostNode(
   1064     const string& subgraph_name, const string& oc_subgraph_name,
   1065     OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) {
   1066   std::vector<DataType> dtypes(oc_subgraph->inputs.size(), DT_INVALID);
   1067 
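           // The _RecvAtHost node produces one output per recorded input tensor;
           // collect the output dtypes in slot order.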
   1068   for (const auto& input : oc_subgraph->inputs) {
   1069     const Node* src_node = input.first.node;
   1070     int src_slot = input.first.slot;
   1071     int input_index = input.second;
   1072 
   1073     DataType dtype = src_node->output_type(src_slot);
   1074     dtypes[input_index] = dtype;
   1075   }
   1076 
   1077   NodeDef recv_def;
   1078   NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
   1079                                          "_", oc_subgraph_name, "_recv"),
   1080                          kRecvAtHostOp);
   1081   builder.Attr("Toutputs", dtypes);
   1082   builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
   1083                                       "_", oc_subgraph_name));
   1084   Status s = builder.Finalize(&recv_def);
   1085   if (!s.ok()) return s;
   1086 
   1087   oc_subgraph->recv_at_host = graph_out->AddNode(recv_def, &s);
   1088   if (!s.ok()) return s;
   1089   oc_subgraph->recv_at_host->set_assigned_device_name(device_);
   1090 
   1091   // Add a control dependency forcing the RecvAtHost to run before the subgraph
   1092   // completes. This has no effect on execution order but prevents the
    1093   // RecvAtHost from being pruned.
   1094   TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out));
   1095   graph_out->AddControlEdge(oc_subgraph->recv_at_host, sequencer_);
   1096 
   1097   return Status::OK();
   1098 }
   1099 
   1100 Status Encapsulator::Subgraph::AddSendFromHostNode(
   1101     const std::unordered_map<const Node*, Node*>& node_images,
   1102     const string& subgraph_name, const string& oc_subgraph_name,
   1103     OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) {
   1104   std::vector<DataType> dtypes(oc_subgraph->outputs_by_src.size(), DT_INVALID);
   1105   std::vector<NodeDefBuilder::NodeOut> inputs(
   1106       oc_subgraph->outputs_by_src.size());
   1107 
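           // The _SendFromHost node consumes one input per recorded output tensor;
           // collect the input dtypes and NodeOuts in slot order.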
   1108   for (const auto& output : oc_subgraph->outputs_by_src) {
   1109     const Node* src_node = output.first.node;
   1110     Node* src_image = node_images.at(src_node);
   1111     int src_slot = output.first.slot;
   1112     int output_index = output.second;
   1113 
   1114     DataType dtype = src_node->output_type(src_slot);
   1115     dtypes[output_index] = dtype;
   1116     inputs[output_index].Reset(src_image->name(), src_slot, dtype);
   1117   }
   1118 
   1119   NodeDef send_def;
   1120   NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
   1121                                          "_", oc_subgraph_name, "_send"),
   1122                          kSendFromHostOp);
   1123   builder.Attr("Tinputs", dtypes);
   1124   builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
   1125                                       "_", oc_subgraph_name));
   1126   builder.Input(inputs);
   1127   Status s = builder.Finalize(&send_def);
   1128   if (!s.ok()) return s;
   1129 
   1130   oc_subgraph->send_from_host = graph_out->AddNode(send_def, &s);
   1131   if (!s.ok()) return s;
   1132   oc_subgraph->send_from_host->set_assigned_device_name(device_);
   1133 
   1134   // Add a control dependency forcing the SendFromHost to run before the
   1135   // subgraph completes. This has no effect on execution order but prevents the
   1136   // SendFromHost being pruned.
   1137   TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out));
   1138   graph_out->AddControlEdge(oc_subgraph->send_from_host, sequencer_);
   1139 
   1140   return Status::OK();
   1141 }
   1142 
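        // Adds, for each outside_compilation cluster in this subgraph, the
        // _RecvAtHost node (if the cluster has data or control inputs) and the
        // _SendFromHost node (if it has data or control outputs) to graph_out.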
   1143 Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes(
   1144     const string& subgraph_name,
   1145     const std::unordered_map<const Node*, Node*>& node_images,
   1146     Graph* graph_out) {
   1147   for (auto& outside_compilation_subgraph_entry :
   1148        outside_compilation_subgraphs_) {
   1149     const string& oc_name = outside_compilation_subgraph_entry.first;
   1150     OutsideCompilationSubgraph& oc_subgraph =
   1151         outside_compilation_subgraph_entry.second;
   1152 
   1153     if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) {
   1154       TF_RETURN_IF_ERROR(
   1155           AddRecvAtHostNode(subgraph_name, oc_name, &oc_subgraph, graph_out));
   1156     }
   1157 
   1158     if (!oc_subgraph.outputs_by_src.empty() ||
   1159         !oc_subgraph.control_outputs.empty()) {
   1160       TF_RETURN_IF_ERROR(AddSendFromHostNode(node_images, subgraph_name,
   1161                                              oc_name, &oc_subgraph, graph_out));
   1162     }
   1163   }
   1164   return Status::OK();
   1165 }
   1166 
   1167 void Encapsulator::Subgraph::GetOutsideCompilationSubgraphNames(
   1168     std::vector<string>* names) const {
   1169   for (auto& entry : outside_compilation_subgraphs_) {
   1170     names->push_back(entry.first);
   1171   }
   1172 }
   1173 
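        // Looks up the group attribute and the outside_compilation attribute on
        // 'node', clearing the corresponding output string when an attribute is
        // absent. A node that has an outside_compilation attribute but no group
        // attribute is an error.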
   1174 Status Encapsulator::GetFunctionNameAttr(
   1175     Node const* node, string* attr, string* outside_compilation_attr) const {
   1176   Status s = GetNodeAttr(node->attrs(), group_attribute_, attr);
   1177   if (s.code() == error::Code::NOT_FOUND) {
   1178     // Return empty attr if there's no group_attribute.
   1179     attr->clear();
   1180   } else {
   1181     TF_RETURN_IF_ERROR(s);
   1182   }
   1183   bool has_group_attr = s.ok();
   1184   s = GetNodeAttr(node->attrs(), outside_compilation_attribute_,
   1185                   outside_compilation_attr);
   1186   if (s.code() == error::Code::NOT_FOUND) {
   1187     // Return empty attr if there's no outside_compilation attribute.
   1188     outside_compilation_attr->clear();
   1189   } else {
   1190     TF_RETURN_IF_ERROR(s);
   1191     if (!has_group_attr) {
   1192       return errors::InvalidArgument(
   1193           "Node ", node->name(), " has ", outside_compilation_attribute_,
   1194           " attribute but no ", group_attribute_, " attribute.");
   1195     }
   1196   }
   1197   return Status::OK();
   1198 }
   1199 
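        // A node is "in a subgraph" if it has a group attribute and is not part of
        // an outside_compilation cluster.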
   1200 bool IsInSubgraph(const string& func_id, const string& outside_compilation_id) {
   1201   return !func_id.empty() && outside_compilation_id.empty();
   1202 }
   1203 
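        // Copies each node that belongs to a compilation subgraph into that
        // subgraph's graph, dropping the group attribute from the copy and
        // recording the node-to-image mapping in 'node_images'.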
   1204 Status Encapsulator::CopySubgraphNodes(
   1205     std::unordered_map<const Node*, Node*>* node_images) {
   1206   for (Node* node : graph_in_->op_nodes()) {
   1207     string func_id;
   1208     string outside_compilation_id;
   1209     TF_RETURN_IF_ERROR(
   1210         GetFunctionNameAttr(node, &func_id, &outside_compilation_id));
   1211     if (!IsInSubgraph(func_id, outside_compilation_id)) continue;
   1212 
   1213     Subgraph& subgraph = subgraphs_[func_id];
   1214     Node* image = subgraph.MakeNodeImage(graph_in_, node);
   1215     image->ClearAttr(group_attribute_);
   1216     (*node_images)[node] = image;
   1217   }
   1218   return Status::OK();
   1219 }
   1220 
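        // Copies edges that are local to a single subgraph into that subgraph, and
        // records edges that cross a subgraph boundary as args, results, or
        // outside_compilation inputs/outputs of the affected subgraphs.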
   1221 Status Encapsulator::CopySubgraphEdges(
   1222     const std::unordered_map<const Node*, Node*>& node_images,
   1223     std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
   1224   for (const Edge* edge : graph_in_->edges()) {
   1225     string src_func_id;
   1226     string src_outside_compilation_id;
   1227     TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id,
   1228                                            &src_outside_compilation_id));
   1229     string dst_func_id;
   1230     string dst_outside_compilation_id;
   1231     TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id,
   1232                                            &dst_outside_compilation_id));
   1233     Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr);
   1234     Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr);
   1235 
   1236     // Copy edges that are local to a subgraph.
   1237     if (IsInSubgraph(src_func_id, src_outside_compilation_id) &&
   1238         IsInSubgraph(dst_func_id, dst_outside_compilation_id) &&
   1239         src_func_id == dst_func_id) {
   1240       Graph* g = subgraphs_[src_func_id].GetGraph();
   1241       if (edge->IsControlEdge()) {
   1242         g->AddControlEdge(src_image, dst_image);
   1243       } else {
   1244         g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input());
   1245       }
   1246       continue;
   1247     }
   1248 
   1249     // Record 'src' as an output of its subgraph, if applicable.
   1250     if (IsInSubgraph(src_func_id, src_outside_compilation_id)) {
   1251       if (!edge->IsControlEdge()) {
   1252         DataType dtype = edge->src()->output_type(edge->src_output());
   1253         if (IsRefType(dtype)) {
   1254           return errors::InvalidArgument(
   1255               "Ref Tensors (e.g., Variables) are not supported as results: "
   1256               "tensor ",
   1257               edge->src()->name(), ":", edge->src_output());
   1258         }
   1259       }
   1260 
   1261       Subgraph& src_subgraph = subgraphs_[src_func_id];
   1262       if (src_func_id == dst_func_id) {
   1263         // src is in the subgraph and dst is outside_compilation in the same
   1264         // subgraph.
   1265         src_subgraph.RecordOutsideCompilationInputOrControl(
   1266             dst_outside_compilation_id, edge);
   1267       } else {
   1268         // Ignore control edges leaving the subgraph. We will lift them onto the
   1269         // enclosing call operators in BuildOutputGraph().
   1270         if (!edge->IsControlEdge()) {
   1271           TF_RETURN_IF_ERROR(src_subgraph.RecordResult(edge, node_images));
   1272         }
   1273       }
   1274     }
   1275 
   1276     // Record 'dst' as an input of its subgraph, if applicable.
   1277     if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) {
   1278       // Look at the type of the destination, not the source, since Ref output
   1279       // Tensors can be automatically cast to non-Ref Tensors at the
   1280       // destination.
   1281       if (!edge->IsControlEdge()) {
   1282         DataType dtype = edge->dst()->input_type(edge->dst_input());
   1283         if (IsRefType(dtype)) {
   1284           return errors::InvalidArgument(
   1285               "Ref Tensors (e.g., Variables) are not supported as args: "
   1286               "tensor ",
   1287               edge->src()->name(), ":", edge->src_output());
   1288         }
   1289       }
   1290 
   1291       Subgraph& dst_subgraph = subgraphs_[dst_func_id];
   1292       if (src_func_id == dst_func_id) {
   1293         // dst is in the subgraph and src is outside_compilation in the same
   1294         // subgraph.
   1295         dst_subgraph.RecordOutsideCompilationOutputOrControl(
   1296             src_outside_compilation_id, edge);
   1297       } else {
   1298         // Ignore control edges entering the subgraph. We will lift them onto
   1299         // the enclosing call operators in BuildOutputGraph().
   1300         if (!edge->IsControlEdge()) {
   1301           TF_RETURN_IF_ERROR(
   1302               dst_subgraph.RecordArg(edge, node_images, src_arg_pairs));
   1303         }
   1304       }
   1305     }
   1306   }
   1307   return Status::OK();
   1308 }
   1309 
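        // Splits graph_in_ into one subgraph per cluster: copies nodes and edges,
        // adds the host compute nodes for nested outside_compilation clusters, and
        // marks the arguments that are guaranteed to be constant.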
   1310 Status Encapsulator::SplitIntoSubgraphs() {
   1311   Status s;
   1312 
   1313   // Map from input graph nodes to subgraph nodes.
   1314   std::unordered_map<const Node*, Node*> node_images;
   1315 
   1316   // Each entry of src_arg_pairs is a pair whose first element is a node in the
   1317   // original graph that has an output edge in the subgraph, and whose second
   1318   // element is the arg node in the subgraph that it sends to. The vector is
   1319   // filled in by RecordArg, called from CopySubgraphEdges below.
   1320   std::vector<std::pair<const Node*, Node*>> src_arg_pairs;
   1321 
   1322   TF_RETURN_IF_ERROR(CopySubgraphNodes(&node_images));
   1323   TF_RETURN_IF_ERROR(CopySubgraphEdges(node_images, &src_arg_pairs));
   1324 
   1325   // For each subgraph, add the nodes that deal with inputs and outputs of its
   1326   // nested outside_compilation subgraphs. These could not be added earlier
   1327   // during CopySubgraphEdges since we need to discover all the types of the
   1328   // inputs and outputs for an outside_compilation subgraph before creating a
   1329   // single input and output node for it.
   1330   for (auto& entry : subgraphs_) {
   1331     Subgraph& subgraph = entry.second;
   1332     TF_RETURN_IF_ERROR(subgraph.AddHostComputes(entry.first, node_images));
   1333   }
   1334 
   1335   MarkGuaranteedConstants(*graph_in_, src_arg_pairs);
   1336 
   1337   for (auto& entry : subgraphs_) {
   1338     Subgraph& subgraph = entry.second;
   1339     FixupSourceAndSinkEdges(subgraph.GetGraph());
   1340   }
   1341 
   1342   return s;
   1343 }
   1344 
   1345 Status Encapsulator::BuildFunctionDefs(
   1346     const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
   1347     FunctionLibraryDefinition* library) {
   1348   for (auto& subgraph_entry : subgraphs_) {
   1349     string name = subgraph_entry.first;
   1350     Subgraph& subgraph = subgraph_entry.second;
   1351     TF_RETURN_IF_ERROR(subgraph.BuildFunctionDef(
   1352         name, rewrite_subgraph_fn, reuse_existing_functions, library));
   1353   }
   1354   return Status::OK();
   1355 }
   1356 
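        // Copies into graph_out every node that is not being encapsulated,
        // including outside_compilation nodes (with their clustering attributes
        // stripped). Encapsulated nodes are also copied when parallel checking is
        // enabled.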
   1357 Status Encapsulator::CopyNodesToOutputGraph(
   1358     bool parallel_checking, Graph* graph_out,
   1359     std::unordered_map<const Node*, Node*>* node_images) {
   1360   for (Node* node : graph_in_->op_nodes()) {
   1361     string func_id;
   1362     string outside_compilation_id;
   1363     TF_RETURN_IF_ERROR(
   1364         GetFunctionNameAttr(node, &func_id, &outside_compilation_id));
   1365 
   1366     // Don't copy nodes that are going to be encapsulated, unless parallel
   1367     // checking is enabled.
   1368     if (IsInSubgraph(func_id, outside_compilation_id) && !parallel_checking)
   1369       continue;
   1370 
   1371     Node* image = graph_out->CopyNode(node);
   1372     if (!outside_compilation_id.empty()) {
   1373       if (parallel_checking) {
   1374         return errors::InvalidArgument(
   1375             "Parallel checking is not supported when outside_compilation "
   1376             "clusters are present.");
   1377       }
   1378       image->ClearAttr(group_attribute_);
   1379       image->ClearAttr(outside_compilation_attribute_);
   1380     }
   1381     (*node_images)[node] = image;
   1382   }
   1383   (*node_images)[graph_in_->source_node()] = graph_out->source_node();
   1384   (*node_images)[graph_in_->sink_node()] = graph_out->sink_node();
   1385   return Status::OK();
   1386 }
   1387 
   1388 Status Encapsulator::AddFunctionCallNodes(
   1389     const std::unordered_map<const Node*, Node*>& node_images,
   1390     bool parallel_checking, Graph* graph_out) {
   1391   for (auto& subgraph_entry : subgraphs_) {
   1392     TF_RETURN_IF_ERROR(subgraph_entry.second.AddFunctionCallNode(
   1393         node_images, parallel_checking, graph_out));
   1394   }
   1395   return Status::OK();
   1396 }
   1397 
   1398 Status Encapsulator::AddOutsideCompilationHostIONodes(
   1399     const std::unordered_map<const Node*, Node*>& node_images,
   1400     Graph* graph_out) {
   1401   for (auto& subgraph_entry : subgraphs_) {
   1402     const string& subgraph_name = subgraph_entry.first;
   1403     Subgraph& subgraph = subgraph_entry.second;
   1404     TF_RETURN_IF_ERROR(subgraph.AddOutsideCompilationHostIONodes(
   1405         subgraph_name, node_images, graph_out));
   1406   }
   1407   return Status::OK();
   1408 }
   1409 
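        // Finds the node in graph_out that acts as the source of a copied edge:
        // the subgraph's _RecvAtHost node if the edge feeds an outside_compilation
        // cluster in the same subgraph, the subgraph's call node if the edge leaves
        // the subgraph, or the plain node image otherwise.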
   1410 Status Encapsulator::FindOutputImageOfEdgeSrc(
   1411     const string& src_func_id, const string& src_outside_compilation_id,
   1412     const string& dst_func_id, const string& dst_outside_compilation_id,
   1413     const std::unordered_map<const Node*, Node*>& node_images,
   1414     const Node* original_src_node, Node** src_image) {
   1415   if (IsInSubgraph(src_func_id, src_outside_compilation_id)) {
   1416     if (dst_func_id == src_func_id) {
   1417       // The edge is from a subgraph to an outside_compilation cluster in the
   1418       // same subgraph so use the appropriate _RecvAtHost node in the output
   1419       // graph.
   1420       TF_RET_CHECK(!dst_outside_compilation_id.empty());
   1421       *src_image = subgraphs_.at(src_func_id)
   1422                        .GetRecvAtHostNode(dst_outside_compilation_id);
   1423     } else {
   1424       // The edge is from a subgraph to a regular node in the output graph so
   1425       // use the subgraph's call node output.
   1426       *src_image = subgraphs_.at(src_func_id).GetCallNodeForOutputs();
   1427     }
   1428   } else {
   1429     // The source of the edge is in the output graph so use the node image in
   1430     // the output graph.
   1431     *src_image = node_images.at(original_src_node);
   1432   }
   1433   return Status::OK();
   1434 }
   1435 
   1436 int Encapsulator::FindOutputSlotOfEdgeSrc(
   1437     const string& src_func_id, const string& src_outside_compilation_id,
   1438     const string& dst_func_id, const string& dst_outside_compilation_id,
   1439     const Edge* edge) {
   1440   if (IsInSubgraph(src_func_id, src_outside_compilation_id)) {
   1441     const Subgraph& src_subgraph = subgraphs_.at(src_func_id);
   1442     if (src_func_id == dst_func_id) {
   1443       // 'src' is in a subgraph and 'dst' is outside_compilation in the same
   1444       // subgraph. Use the corresponding _RecvAtHost output instead.
   1445       return src_subgraph.GetRecvAtHostSlot(dst_outside_compilation_id, edge);
   1446     } else {
   1447       // 'src' is in a subgraph and 'dst' is a regular node in the output
   1448       // graph. Use the corresponding call output instead.
   1449       return src_subgraph.GetResultIndexForEdge(edge);
   1450     }
   1451   } else {
   1452     // The source of the edge is in the output graph so use the regular edge
   1453     // slot.
   1454     return edge->src_output();
   1455   }
   1456 }
   1457 
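        // Finds the node in graph_out that acts as the destination of a copied
        // edge: the subgraph's _SendFromHost node if the edge comes from an
        // outside_compilation cluster in the same subgraph, the subgraph's call
        // node if the edge enters the subgraph from outside, or the plain node
        // image otherwise.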
   1458 Status Encapsulator::FindOutputImageOfEdgeDst(
   1459     const string& src_func_id, const string& src_outside_compilation_id,
   1460     const string& dst_func_id, const string& dst_outside_compilation_id,
   1461     const std::unordered_map<const Node*, Node*>& node_images,
   1462     const Node* original_dst_node, Node** dst_image) {
   1463   if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) {
   1464     if (src_func_id == dst_func_id) {
   1465       // The edge is to a subgraph from an outside_compilation cluster in the
   1466       // same subgraph so use the appropriate _SendFromHost node in the output
   1467       // graph.
   1468       TF_RET_CHECK(!src_outside_compilation_id.empty());
   1469       *dst_image = subgraphs_.at(dst_func_id)
   1470                        .GetSendFromHostNode(src_outside_compilation_id);
   1471     } else {
   1472       // The edge is to a subgraph from a regular node in the output graph so
   1473       // use the subgraph's call node input.
   1474       *dst_image = subgraphs_.at(dst_func_id).GetCallNodeForInputs();
   1475     }
   1476   } else {
   1477     // The destination of the edge is in the output graph so use the node image
   1478     // in the output graph.
   1479     *dst_image = node_images.at(original_dst_node);
   1480   }
   1481   return Status::OK();
   1482 }
   1483 
   1484 int Encapsulator::FindOutputSlotOfEdgeDst(
   1485     const string& src_func_id, const string& src_outside_compilation_id,
   1486     const string& dst_func_id, const string& dst_outside_compilation_id,
   1487     const Edge* edge) {
   1488   if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) {
   1489     const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id);
   1490     if (dst_func_id == src_func_id) {
   1491       // 'dst' is in a subgraph and 'src' is outside_compilation in the same
   1492       // subgraph. Use the corresponding _SendFromHost input instead.
   1493       return dst_subgraph.GetSendFromHostSlot(src_outside_compilation_id, edge);
   1494     } else {
   1495       // 'dst' is in a subgraph and 'src' is a regular node in the output
   1496       // graph. Use the corresponding call input instead.
   1497       return dst_subgraph.GetArgIndexForEdge(edge);
   1498     }
   1499   } else {
   1500     // The destination of the edge is in the output graph so use the regular
   1501     // edge slot.
   1502     return edge->dst_input();
   1503   }
   1504 }
   1505 
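        // Copies a single edge of graph_in_ into graph_out, remapping its endpoints
        // and slots onto call nodes and _RecvAtHost/_SendFromHost nodes as needed.
        // 'edges_added' deduplicates edges, since several input edges can map to
        // the same output edge.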
   1506 Status Encapsulator::CopyEdgeToOutputGraph(
   1507     const Edge* edge, const string& src_func_id,
   1508     const string& src_outside_compilation_id, const string& dst_func_id,
   1509     const string& dst_outside_compilation_id,
   1510     const std::unordered_map<const Node*, Node*>& node_images,
   1511     bool parallel_checking, Graph* graph_out,
   1512     std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>*
   1513         edges_added) {
   1514   Node* src_image;
   1515   TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc(
   1516       src_func_id, src_outside_compilation_id, dst_func_id,
   1517       dst_outside_compilation_id, node_images, edge->src(), &src_image));
   1518   Node* dst_image;
   1519   TF_RETURN_IF_ERROR(FindOutputImageOfEdgeDst(
   1520       src_func_id, src_outside_compilation_id, dst_func_id,
   1521       dst_outside_compilation_id, node_images, edge->dst(), &dst_image));
   1522 
   1523   // If this is a control edge then copy it and return. Lift control edges onto
   1524   // the enclosing call operator.
   1525   if (edge->IsControlEdge()) {
   1526     // Add the control edge, if we have not already added it, using the images
   1527     // determined above (potentially call operators or RecvAtHost/SendFromHost).
   1528     if (edges_added->emplace(NodeSlot(src_image, -1), NodeSlot(dst_image, -1))
   1529             .second) {
   1530       graph_out->AddControlEdge(src_image, dst_image);
   1531     }
   1532 
   1533     // If parallel checking is enabled, also add a control edge to the
   1534     // corresponding parallel check op.
   1535     if (parallel_checking) {
   1536       graph_out->AddControlEdge(src_image, node_images.at(edge->dst()));
   1537     }
   1538     return Status::OK();
   1539   }
   1540 
   1541   int src_output =
   1542       FindOutputSlotOfEdgeSrc(src_func_id, src_outside_compilation_id,
   1543                               dst_func_id, dst_outside_compilation_id, edge);
   1544 
   1545   int dst_input =
   1546       FindOutputSlotOfEdgeDst(src_func_id, src_outside_compilation_id,
   1547                               dst_func_id, dst_outside_compilation_id, edge);
   1548 
   1549   if (IsInSubgraph(dst_func_id, dst_outside_compilation_id) &&
   1550       parallel_checking) {
   1551     // If we are parallel checking, also feed the tensor as an input to the
   1552     // corresponding parallel check subgraph.
   1553     graph_out->AddEdge(src_image, src_output, node_images.at(edge->dst()),
   1554                        edge->dst_input());
   1555   }
   1556 
   1557   // Add the edge, if we have not already added it.
   1558   if (edges_added
   1559           ->emplace(NodeSlot(src_image, src_output),
   1560                     NodeSlot(dst_image, dst_input))
   1561           .second) {
   1562     graph_out->AddEdge(src_image, src_output, dst_image, dst_input);
   1563   }
   1564   return Status::OK();
   1565 }
   1566 
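        // Copies every edge of graph_in_ into graph_out. Edges that are strictly
        // internal to one subgraph are skipped (or routed to the parallel check
        // copies when parallel checking is on), and each subgraph's sequencer is
        // connected into the output graph at the end.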
   1567 Status Encapsulator::AddEdgesToOutputGraph(
   1568     const std::unordered_map<const Node*, Node*>& node_images,
   1569     bool parallel_checking, Graph* graph_out) {
   1570   // Set of edges already added to the output graph, represented as (src, dst)
   1571   // pairs. We use the set to deduplicate edges; multiple edges in the input
   1572   // graph may map to one edge in the output graph.
   1573   std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>
   1574       edges_added;
   1575 
   1576   for (const Edge* edge : graph_in_->edges()) {
   1577     string src_func_id;
   1578     string src_outside_compilation_id;
   1579     TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id,
   1580                                            &src_outside_compilation_id));
   1581     string dst_func_id;
   1582     string dst_outside_compilation_id;
   1583     TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id,
   1584                                            &dst_outside_compilation_id));
   1585 
   1586     // Ignore edges that are strictly contained within one subgraph, unless
   1587     // we are constructing parallel check graphs.
   1588     if (IsInSubgraph(src_func_id, src_outside_compilation_id) &&
   1589         IsInSubgraph(dst_func_id, dst_outside_compilation_id) &&
   1590         src_func_id == dst_func_id) {
   1591       if (parallel_checking) {
   1592         Node* src_image = node_images.at(edge->src());
   1593         Node* dst_image = node_images.at(edge->dst());
   1594         if (edge->IsControlEdge()) {
   1595           graph_out->AddControlEdge(src_image, dst_image);
   1596         } else {
   1597           graph_out->AddEdge(src_image, edge->src_output(), dst_image,
   1598                              edge->dst_input());
   1599         }
   1600       }
   1601       continue;
   1602     }
   1603 
   1604     // We have an edge that crosses a cluster boundary or is entirely within the
   1605     // unclustered graph.
   1606     TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph(
   1607         edge, src_func_id, src_outside_compilation_id, dst_func_id,
   1608         dst_outside_compilation_id, node_images, parallel_checking, graph_out,
   1609         &edges_added));
   1610   }
   1611 
   1612   for (auto& subgraph_entry : subgraphs_) {
   1613     Subgraph& subgraph = subgraph_entry.second;
   1614     subgraph.ConnectSequencerToOutputs(graph_out);
   1615   }
   1616 
   1617   return Status::OK();
   1618 }
   1619 
   1620 namespace {
   1621 
   1622 // Adds a dummy Const node to graph_out. The "constant" has the type of
   1623 // data_type and the shape indicated in 'shape'. The dummy node is not a valid
   1624 // Const node because it does not have any value defined, but this doesn't
   1625 // matter because it will only be used subsequently for shape inference. (It
   1626 // would be possible to add a switch statement over data_type to create a value
   1627 // for the constant, but that would entail maintaining the logic as new types
   1628 // are added, and is not necessary.)
   1629 Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape,
   1630                          Graph* graph_out) {
   1631   TensorProto dummy_proto;
   1632   dummy_proto.set_dtype(data_type);
   1633   *dummy_proto.mutable_tensor_shape() = shape;
   1634   // Don't set any value field in the proto, since it is only going to be used
   1635   // for shape inference.
   1636 
   1637   GraphDefBuilder::Options options(graph_out, /*status=*/nullptr);
   1638   NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const",
   1639                            options.op_registry());
   1640   node_builder.Attr("dtype", data_type).Attr("value", dummy_proto);
   1641   return options.FinalizeBuilder(&node_builder);
   1642 }
   1643 
   1644 // Adds a copy of node_in to graph_out and adds the mapping to
   1645 // copied_node_images.
   1646 Status CopyShapeInferenceNodeToGraph(
   1647     Node* node_in, const Node* send_node,
   1648     const std::unordered_map<Node*, Node*>& dummy_node_images,
   1649     FunctionLibraryDefinition* library,
   1650     std::unordered_map<Node*, Node*>* copied_node_images, Graph* graph_out) {
   1651   // Once all the ancestor nodes have been added to graph_out, add this node
   1652   // and connect it to its ancestors.
   1653   Node* node_out = graph_out->CopyNode(node_in);
   1654   (*copied_node_images)[node_in] = node_out;
   1655   // Don't bother to build the shape inference graph if there's a node with no
   1656   // shape inference function, since it would just result in an error later at
   1657   // compile time.
   1658   const OpRegistrationData* op_reg_data;
   1659   TF_RETURN_IF_ERROR(library->LookUp(node_in->type_string(), &op_reg_data));
   1660   if (op_reg_data->shape_inference_fn == nullptr) {
   1661     return errors::InvalidArgument(
   1662         "Shape inference is not possible for outside_compilation "
   1663         "SendFromHost node ",
   1664         send_node->name(), " because it depends on node ", node_in->name(),
   1665         " which does not have a shape inference function registered.");
   1666   }
   1667   // Add all the edges to the newly copied node.
   1668   for (const Edge* in_edge : node_in->in_edges()) {
   1669     if (!in_edge->IsControlEdge()) {
   1670       Node* src = in_edge->src();
   1671       const auto iter = dummy_node_images.find(src);
   1672       if (iter == dummy_node_images.end()) {
   1673         // The src is a copied node so use the original output port.
   1674         graph_out->AddEdge((*copied_node_images)[in_edge->src()],
   1675                            in_edge->src_output(), node_out,
   1676                            in_edge->dst_input());
   1677       } else {
   1678         // The src is a dummy node so use output port 0.
   1679         graph_out->AddEdge(iter->second, 0, node_out, in_edge->dst_input());
   1680       }
   1681     }
   1682   }
   1683   return Status::OK();
   1684 }
   1685 
   1686 }  // namespace
   1687 
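        // For one _SendFromHost node, either computes the static shapes of all of
        // its inputs (leaving *graphdef_out null), or builds a small GraphDef that
        // can finish the shape inference at compile time once the _RecvAtHost
        // shapes are known. Returns an error if neither is possible.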
   1688 Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
   1689     const Graph& graph_in, const ShapeRefiner& shape_refiner,
   1690     const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
   1691     FunctionLibraryDefinition* library,
   1692     std::vector<TensorShapeProto>* static_shape_out,
   1693     std::unique_ptr<GraphDef>* graphdef_out) {
   1694   // Maps from nodes in graph_in to nodes in graph_out.
   1695   //
   1696   // When an edge has fully defined shape the source node in graph_in is
   1697   // replaced in graph_out by a dummy constant node. The mapping from nodes
   1698   // in graph_in to dummy nodes is stored in dummy_node_images.
   1699   //
   1700   // When a node in graph_in has at least one ancestor that doesn't have fully
   1701   // defined shape, it is copied into graph_out. The mapping from nodes in
   1702   // graph_in to copied nodes is stored in copied_node_images.
   1703   //
   1704   // The two types of node are treated differently because, when adding edges to
   1705   // graph_out, an output from a dummy node always uses port 0, whereas an
   1706   // output from a copied node uses the same port that was used in graph_in.
   1707   std::unordered_map<Node*, Node*> dummy_node_images;
   1708   std::unordered_map<Node*, Node*> copied_node_images;
   1709 
   1710   std::unique_ptr<Graph> graph_out(new Graph(graph_in.op_registry()));
   1711   graph_out->set_versions(graph_in.versions());
   1712   static_shape_out->resize(send_node->num_inputs());
   1713 
   1714   // We don't use the standard ReverseDFS because we want to cut off traversal
   1715   // whenever we find an output with fully defined shape.
   1716   // TODO(misard) make this work properly in the presence of control flow.
   1717   struct Work {
   1718     Node* node;
   1719     bool leave;  // Are we entering or leaving node?
   1720   };
   1721   std::vector<Work> stack({{send_node, false}});
   1722   std::vector<bool> visited(graph_in.num_node_ids(), false);
   1723   while (!stack.empty()) {
   1724     Work w = stack.back();
   1725     stack.pop_back();
   1726     Node* n = w.node;
   1727 
   1728     if (w.leave) {
   1729       TF_RETURN_IF_ERROR(CopyShapeInferenceNodeToGraph(
   1730           n, send_node, dummy_node_images, library, &copied_node_images,
   1731           graph_out.get()));
   1732     } else {
   1733       if (visited[n->id()]) continue;
   1734       visited[n->id()] = true;
   1735 
   1736       // Arrange to revisit this node after all of its inputs are processed.
   1737       stack.push_back(Work{n, true});
   1738 
   1739       bool has_parent_with_unknown_shape = false;
   1740       for (const Edge* in_edge : n->in_edges()) {
   1741         if (!in_edge->IsControlEdge()) {
   1742           Node* src_node = in_edge->src();
   1743           int src_port = in_edge->src_output();
   1744           shape_inference::InferenceContext* context =
   1745               shape_refiner.GetContext(src_node);
   1746           shape_inference::ShapeHandle shape = context->output(src_port);
   1747           if (context->FullyDefined(shape)) {
   1748             // This ancestor has known shape, so instead of adding it to the
   1749             // stack, add a dummy node with that shape to graph_out and
   1750             // continue.
   1751             TensorShapeProto proto;
   1752             context->ShapeHandleToProto(shape, &proto);
   1753             dummy_node_images[src_node] = AddDummyShapedNode(
   1754                 src_node->output_type(src_port), proto, graph_out.get());
   1755             if (n == send_node) {
   1756               (*static_shape_out)[in_edge->dst_input()] = proto;
   1757             }
   1758           } else {
   1759             if (!visited[src_node->id()]) {
   1760               has_parent_with_unknown_shape = true;
   1761               stack.push_back({src_node, false});
   1762             }
   1763           }
   1764         }
   1765       }
   1766       if (!has_parent_with_unknown_shape) {
   1767         if (n == send_node) {
   1768           // The shapes of all the inputs to send_node are statically known. We
   1769           // won't have to do any inference at compile time so return now: the
   1770           // shapes were stored in static_shape_out above.
   1771           graphdef_out->reset();
   1772           return Status::OK();
   1773         } else {
   1774           // Any node that is being processed here is either the original send
   1775           // node, or has at least one output with statically-unknown shape. If
   1776           // the latter, and it doesn't have any inputs with statically-unknown
   1777           // shape, then check that it is one of the recv nodes whose shape we
   1778           // can fill in at run time later. If it isn't one of those, then we
   1779           // won't have any additional knowledge at compile time, so we already
   1780           // know we won't be able to do shape inference and we can return an
   1781           // error now.
   1782           if (recv_at_host_nodes.find(n->name()) == recv_at_host_nodes.end()) {
   1783             return errors::InvalidArgument(
   1784                 "Shape inference is not possible for outside_compilation "
   1785                 "SendFromHost node ",
   1786                 send_node->name(), " because shape of node ", n->name(),
   1787                 " will not be known at compilation time.");
   1788           }
   1789         }
   1790       }
   1791     }
   1792   }
   1793 
   1794   graphdef_out->reset(new GraphDef());
   1795   graph_out->ToGraphDef(graphdef_out->get());
   1796 
   1797   return Status::OK();
   1798 }
   1799 
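        // Copies the ancestors of 'sink_nodes' into a fresh graph, preserving the
        // edges between copied nodes, and then inlines any function call nodes
        // among them so that later shape inference can see into the called
        // functions.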
   1800 Status Encapsulator::MakePrunedGraphCopyAndInline(
   1801     const Graph& graph, const std::vector<Node*>& sink_nodes,
   1802     std::unique_ptr<Graph>* pruned_graph,
   1803     std::unordered_map<const Node*, Node*>* node_images,
   1804     FunctionLibraryDefinition* library) {
   1805   // First copy all ancestor nodes of sink_nodes into a new graph.
   1806   pruned_graph->reset(new Graph(library));
   1807   (*pruned_graph)->set_versions(graph.versions());
   1808   ReverseDFSFrom(graph, sink_nodes,
   1809                  /*enter=*/nullptr,
   1810                  /*leave=*/[&](Node* n) {
   1811                    if (!n->IsSource()) {
   1812                      Node* copied = (*pruned_graph)->CopyNode(n);
   1813                      node_images->emplace(n, copied);
   1814                    }
   1815                  });
   1816 
   1817   // Add all the edges between copied nodes.
   1818   for (auto entry : *node_images) {
   1819     const Node* orig = entry.first;
   1820     Node* image = entry.second;
   1821     for (const Edge* out_edge : orig->out_edges()) {
   1822       auto iter = node_images->find(out_edge->dst());
   1823       if (iter != node_images->end()) {
   1824         // The source and destination are both in the copied graph.
   1825         (*pruned_graph)
   1826             ->AddEdge(image, out_edge->src_output(), iter->second,
   1827                       out_edge->dst_input());
   1828       }
   1829     }
   1830   }
   1831 
   1832   // Find all the function call nodes, and inline them.
   1833   std::vector<Node*> function_nodes;
   1834   for (auto node : (*pruned_graph)->nodes()) {
   1835     const OpRegistrationData* op_reg_data;
   1836     TF_RETURN_IF_ERROR(library->LookUp(node->type_string(), &op_reg_data));
   1837     if (op_reg_data->is_function_op) {
   1838       function_nodes.push_back(node);
   1839     }
   1840   }
   1841   for (auto node : function_nodes) {
   1842     VLOG(2) << "Inlining function " << node->name();
   1843     const FunctionDef* fdef = library->Find(node->type_string());
   1844     if (fdef == nullptr) {
   1845       return errors::Internal("Failed to find function ", node->type_string(),
   1846                               " in function library.");
   1847     }
   1848     FunctionBody* fbody = nullptr;
   1849     TF_RETURN_IF_ERROR(
   1850         FunctionDefToBodyHelper(*fdef, node->attrs(), library,
   1851                                 [library](const string& op, const OpDef** sig) {
   1852                                   return library->LookUpOpDef(op, sig);
   1853                                 },
   1854                                 &fbody));
   1855     InlineFunctionBody(*library, pruned_graph->get(), node, fbody);
   1856     delete fbody;
   1857   }
   1858 
   1859   return Status::OK();
   1860 }
   1861 
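        // Builds the pruned, inlined copy of 'graph' rooted at every subgraph's
        // _SendFromHost nodes, then runs best-effort shape inference over it in
        // reverse post order.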
   1862 Status Encapsulator::MakeGraphForOutsideCompilationSends(
   1863     const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
   1864     ShapeRefiner* shape_refiner,
   1865     std::unordered_map<const Node*, Node*>* node_images,
   1866     FunctionLibraryDefinition* library) {
   1867   // Find all the send_from_host nodes in all subgraphs, to use as roots for the
   1868   // pruning.
   1869   std::vector<Node*> send_from_host_nodes;
   1870   for (auto& subgraph_entry : subgraphs_) {
   1871     Subgraph& subgraph = subgraph_entry.second;
   1872     std::vector<string> outside_compilation_names;
   1873     subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names);
   1874     for (const auto& name : outside_compilation_names) {
   1875       Node* send_node = subgraph.GetSendFromHostNode(name);
   1876       if (send_node != nullptr) {
   1877         send_from_host_nodes.push_back(send_node);
   1878       }
   1879     }
   1880   }
   1881 
   1882   // Make a copy of all the graph nodes needed to evaluate the send_from_host
   1883   // nodes, inlining any functions as needed.
   1884   TF_RETURN_IF_ERROR(MakePrunedGraphCopyAndInline(
   1885       graph, send_from_host_nodes, pruned_graph, node_images, library));
   1886 
   1887   // Perform shape inference on the pruned graph.
   1888   shape_refiner->set_require_shape_inference_fns(false);
   1889   FixupSourceAndSinkEdges(pruned_graph->get());
   1890   std::vector<Node*> post_order;
   1891   GetReversePostOrder(*(*pruned_graph), &post_order);
   1892   for (auto node : post_order) {
   1893     // Ignore the status returned by the shape_refiner. At this point we want
   1894     // the best effort shapes, even if no shape function is registered for a
   1895     // node.
   1896     Status status = shape_refiner->AddNode(node);
   1897     if (!status.ok()) {
   1898       VLOG(1) << "Shape inference failed for node: " << status;
   1899     }
   1900   }
   1901 
   1902   return Status::OK();
   1903 }
   1904 
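        // For every outside_compilation cluster, records either the static shapes
        // of its _SendFromHost inputs or a graph that can infer them at compile
        // time, and then rewrites the owning subgraph's FunctionDef to carry that
        // information.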
   1905 Status Encapsulator::GetShapeInfoForOutsideCompilationSends(
   1906     Graph* graph_out, FunctionLibraryDefinition* library) {
   1907   std::unique_ptr<Graph> pruned_graph;
   1908   ShapeRefiner shape_refiner(graph_out->versions(), graph_out->op_registry());
   1909   std::unordered_map<const Node*, Node*> node_images;
   1910   TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends(
   1911       *graph_out, &pruned_graph, &shape_refiner, &node_images, library));
   1912 
   1913   for (auto& subgraph_entry : subgraphs_) {
   1914     Subgraph& subgraph = subgraph_entry.second;
   1915     // Find all the recv_at_host nodes in this subgraph.
   1916     std::vector<string> outside_compilation_names;
   1917     subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names);
   1918     std::unordered_set<string> recv_at_host_names;
   1919     for (const auto& name : outside_compilation_names) {
   1920       Node* recv_node = subgraph.GetRecvAtHostNode(name);
   1921       if (recv_node != nullptr) {
   1922         recv_at_host_names.insert(recv_node->name());
   1923       }
   1924     }
   1925     // For each send_from_host node, do as much shape inference as possible
   1926     // without knowing the shape of the recv_at_host nodes, and store the
   1927     // result, along with enough information to complete the job at compile time
   1928     // once the recv_at_host shapes are known.
   1929     for (const auto& name : outside_compilation_names) {
   1930       Node* send_node = subgraph.GetSendFromHostNode(name);
   1931       std::vector<TensorShapeProto> static_shape;
   1932       std::unique_ptr<GraphDef> graphdef;
   1933       if (send_node != nullptr) {
   1934         TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend(
   1935             *pruned_graph, shape_refiner, recv_at_host_names,
   1936             node_images[send_node], library, &static_shape, &graphdef));
   1937         if (graphdef == nullptr) {
   1938           VLOG(2) << "Send node " << send_node->name() << " shapes";
   1939           for (int i = 0; i < static_shape.size(); ++i) {
   1940             VLOG(2) << static_shape[i].DebugString();
   1941           }
   1942         } else {
   1943           VLOG(2) << "Send node " << send_node->name() << " graph\n"
   1944                   << graphdef->DebugString();
   1945         }
   1946       }
   1947       TF_RETURN_IF_ERROR(
   1948           subgraph.AddShapeInferenceInfo(name, static_shape, graphdef.get()));
   1949     }
   1950     if (!outside_compilation_names.empty()) {
   1951       TF_RETURN_IF_ERROR(subgraph.ReplaceFunctionDef(library));
   1952     }
   1953   }
   1954 
   1955   return Status::OK();
   1956 }
   1957 
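        // Assembles the rewritten graph: copied unclustered nodes, one call node
        // per subgraph, host-side I/O nodes for outside_compilation clusters, the
        // remapped edges, and shape information for the outside_compilation sends.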
   1958 Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out,
   1959                                       FunctionLibraryDefinition* library) {
   1960   // Map from nodes in the input graph to nodes in the output graph.
   1961   std::unordered_map<const Node*, Node*> node_images;
   1962 
   1963   TF_RETURN_IF_ERROR(
   1964       CopyNodesToOutputGraph(parallel_checking, graph_out, &node_images));
   1965   TF_RETURN_IF_ERROR(
   1966       AddFunctionCallNodes(node_images, parallel_checking, graph_out));
   1967   TF_RETURN_IF_ERROR(AddOutsideCompilationHostIONodes(node_images, graph_out));
   1968   TF_RETURN_IF_ERROR(
   1969       AddEdgesToOutputGraph(node_images, parallel_checking, graph_out));
   1970 
   1971   TF_RETURN_IF_ERROR(
   1972       GetShapeInfoForOutsideCompilationSends(graph_out, library));
   1973 
   1974   return Status::OK();
   1975 }
   1976 
   1977 }  // anonymous namespace
   1978 
   1979 Status EncapsulateSubgraphsInFunctions(
   1980     string group_attribute, string outside_compilation_attribute,
   1981     const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
   1982     bool parallel_checking, bool reuse_existing_functions,
   1983     std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library) {
   1984   Status s;
   1985 
   1986   Encapsulator encapsulator(std::move(group_attribute),
   1987                             std::move(outside_compilation_attribute),
   1988                             &graph_in);
   1989   TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs());
   1990 
   1991   TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs(
   1992       rewrite_subgraph_fn, reuse_existing_functions, library));
   1993 
   1994   std::unique_ptr<Graph> out(new Graph(library));
   1995   out->set_versions(graph_in.versions());
   1996   TF_RETURN_IF_ERROR(
   1997       encapsulator.BuildOutputGraph(parallel_checking, out.get(), library));
   1998 
   1999   *graph_out = std::move(out);
   2000   return Status::OK();
   2001 }
   2002 
   2003 // Finds the types of the _Arg nodes, indexed by position.
   2004 static Status GetArgTypes(const Graph& graph, DataTypeVector* types) {
   2005   for (Node* n : graph.op_nodes()) {
   2006     if (n->type_string() == kArgOp) {
   2007       int index;
   2008       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
   2009       if (index < 0 || index >= types->size()) {
   2010         return errors::InvalidArgument("Invalid argument number");
   2011       }
   2012       (*types)[index] = n->output_type(0);
   2013     }
   2014   }
   2015   return Status::OK();
   2016 }
   2017 
   2018 // Renumber the indices of _Arg nodes in a graph, according to
   2019 // 'permutation' that maps old indices to new indices.
   2020 static Status RenumberArguments(Graph* graph,
   2021                                 const std::vector<int>& permutation) {
   2022   for (Node* n : graph->op_nodes()) {
   2023     if (n->type_string() == kArgOp) {
   2024       int index;
   2025       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
   2026       if (index < 0 || index >= permutation.size()) {
   2027         return errors::InvalidArgument("Invalid argument number");
   2028       }
   2029       n->AddAttr("index", permutation[index]);
   2030     }
   2031   }
   2032   return Status::OK();
   2033 }
   2034 
   2035 Status EncapsulateSubgraphsPass::Run(
   2036     const GraphOptimizationPassOptions& options) {
   2037   VLOG(1) << "EncapsulateSubgraphsPass::Run";
   2038   legacy_flags::EncapsulateSubgraphsPassFlags* flags =
   2039       legacy_flags::GetEncapsulateSubgraphsPassFlags();
   2040   if (VLOG_IS_ON(1)) {
   2041     dump_graph::DumpGraphToFile("before_encapsulate_subgraphs", **options.graph,
   2042                                 options.flib_def);
   2043   }
   2044 
   2045   std::unique_ptr<Graph> graph_out;
   2046   FunctionLibraryDefinition* const library = options.flib_def;
   2047 
   2048   OptimizerOptions opts;
   2049   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
   2050       new ProcessFunctionLibraryRuntime(nullptr, options.session_options->env,
   2051                                         TF_GRAPH_DEF_VERSION, library, opts));
   2052   FunctionLibraryRuntime* flr =
   2053       pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
   2054 
   2055   auto rewrite_subgraph = [flr](std::unique_ptr<Graph>* subgraph,
   2056                                 std::vector<int>* input_permutation,
   2057                                 std::vector<int>* output_permutation,
   2058                                 NodeDef* node) {
   2059     // Optimize the subgraph.
   2060     OptimizeGraph(flr, subgraph);
   2061 
   2062     const int num_args = input_permutation->size();
   2063     std::vector<bool> const_args(num_args);
   2064     TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args));
   2065 
   2066     DataTypeVector arg_types(num_args);
   2067     TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types));
   2068 
   2069     // Compute a permutation of the arguments such that the constant arguments
   2070     // are first.
   2071     const int num_consts =
   2072         std::count(const_args.begin(), const_args.end(), true);
   2073 
   2074     const int num_resources =
   2075         std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE);
   2076     const int num_nonconsts = num_args - num_resources - num_consts;
   2077     if (num_nonconsts < 0) {
   2078       return errors::Internal("num_nonconsts should be >= 0, was ",
   2079                               num_nonconsts);
   2080     }
   2081 
   2082     int const_pos = 0;
   2083     int arg_pos = num_consts;
   2084     int resource_pos = num_consts + num_nonconsts;
   2085     for (int i = 0; i < num_args; ++i) {
   2086       if (const_args[i]) {
   2087         if (arg_types[i] == DT_RESOURCE) {
   2088           return errors::Internal(
   2089               "Resource arguments cannot be constant (argument ", i, ")");
   2090         }
   2091         (*input_permutation)[i] = const_pos;
   2092         ++const_pos;
   2093       } else if (arg_types[i] == DT_RESOURCE) {
   2094         (*input_permutation)[i] = resource_pos;
   2095         ++resource_pos;
   2096       } else {
   2097         (*input_permutation)[i] = arg_pos;
   2098         ++arg_pos;
   2099       }
   2100     }
   2101 
   2102     // Renumber argument nodes in the graph.
   2103     TF_RETURN_IF_ERROR(RenumberArguments(subgraph->get(), *input_permutation));
   2104 
   2105     // TODO(phawkins): add a forward is-constant analysis, similarly split
   2106     // outputs into host-memory constants and device-memory non-constants.
   2107 
   2108     AddNodeAttr(kXlaCompiledKernelAttr, true, node);
   2109     AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node);
   2110     AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node);
   2111     return Status::OK();
   2112   };
   2113 
   2114   TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
   2115       kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph,
   2116       rewrite_subgraph, flags->tf_xla_parallel_checking,
   2117       /*reuse_existing_functions=*/false, &graph_out, library));
   2118 
   2119   if (VLOG_IS_ON(1)) {
   2120     dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out,
   2121                                 options.flib_def);
   2122   }
   2123 
   2124   *options.graph = std::move(graph_out);
   2125   return Status::OK();
   2126 }
   2127 
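        // Returns true iff 'node' carries the _XlaCompiledKernel attribute with
        // value true.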
   2128 bool IsXlaCompiledKernel(const Node& node) {
   2129   bool is_compiled = false;
   2130   bool has_compilation_attr =
   2131       GetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled).ok() &&
   2132       is_compiled;
   2133   return has_compilation_attr ? is_compiled : false;
   2134 }
   2135 
   2136 }  // namespace tensorflow
   2137