Home | History | Annotate | Download | only in tf2xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
     17 
     18 #include <algorithm>
     19 #include <deque>
     20 #include <stack>
     21 #include <unordered_set>
     22 #include <vector>
     23 
     24 #include "tensorflow/compiler/jit/graph_to_functiondef.h"
     25 #include "tensorflow/compiler/jit/union_find.h"
     26 #include "tensorflow/compiler/tf2xla/dump_graph.h"
     27 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
     28 #include "tensorflow/compiler/xla/ptr_util.h"
     29 #include "tensorflow/compiler/xla/status_macros.h"
     30 #include "tensorflow/core/common_runtime/function.h"
     31 #include "tensorflow/core/framework/node_def_builder.h"
     32 #include "tensorflow/core/graph/algorithm.h"
     33 #include "tensorflow/core/graph/control_flow.h"
     34 #include "tensorflow/core/lib/gtl/optional.h"
     35 
     36 namespace tensorflow {
     37 
     38 namespace {
     39 
     40 using xla::StatusOr;
     41 
     42 const char* const kArgOp = "_Arg";
     43 const char* const kRetValOp = "_Retval";
     44 
     45 // Information about a loop argument.
     46 struct Arg {
     47   // Every loop argument has an Enter node.
     48   Node* enter;
     49 
     50   // Is the loop argument a loop-invariant value? Taken from the `is_constant`
     51   // attribute on the Enter node.
     52   bool is_loop_invariant;
     53 
     54   // If 'is_loop_invariant' is true, the following are all nullptr. Non-constant
     55   // arguments must have all of the following nodes:
     56   Node* merge = nullptr;
     57   Node* switch_node = nullptr;
     58   Node* next_iteration = nullptr;
     59   Node* exit = nullptr;
     60 };
     61 
     62 // Information about a loop frame.
     63 struct Frame {
     64   string name;
     65 
     66   // Pointer to the parent frame. The root frame has a pointer to itself.
     67   Frame* parent = nullptr;
     68   int num_children = 0;
     69 
     70   // Arguments to this loop.
     71   std::vector<Arg> args;
     72 
     73   // The loop condition of the loop. There should be exactly one loop condition
     74   // in every loop.
     75   Node* loop_cond = nullptr;
     76 
     77   // Set of nodes that belong to the loop frame.
     78   std::unordered_set<Node*> nodes;
     79 };
     80 
     81 // Comparison function used for sorting nodes consistently.
     82 // a) resource variables are last, and
     83 // b) sort lexicographically by name (for deterministic output).
     84 struct NodeCmp {
     85   bool operator()(const Node* lhs, const Node* rhs) const {
     86     bool lhs_is_resource =
     87         lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false;
     88     bool rhs_is_resource =
     89         rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false;
     90     return std::tie(lhs_is_resource, lhs->name()) <
     91            std::tie(rhs_is_resource, rhs->name());
     92   }
     93 };
     94 
     95 // Returns a textual representation of the names of the nodes in the input.
     96 template <typename T>
     97 string NodesToString(const T& nodes) {
     98   return strings::StrCat("{",
     99                          str_util::Join(nodes, ",",
    100                                         [](string* output, const Node* node) {
    101                                           strings::StrAppend(output,
    102                                                              node->name());
    103                                         }),
    104                          "}");
    105 }
    106 
    107 // Copies a subgraph from `graph` to `output` by performing a reverse DFS
    108 // starting at nodes in vector `stack`.
    109 // `node_map` is a vector indexed by source node ID to dest nodes.
    110 // Does not traverse into nodes in `node_map`, so by adding nodes to `node_map`
    111 // before the traversal clients can cut the graph. If a frame is provided (frame
    112 // != nullptr), then this functions will return an error if the
    113 // traversal leaves 'frame'; the client must add enough nodes to `node_map` to
    114 // cut the graph and prevent the traversal from escaping.
    115 //
    116 // `squash_src_outputs` contains a bool for each source node ID. If true, then
    117 // the source output on that node will be replaced by zero when copied. This is
    118 // used when replacing a Switch node with an _Arg node. The output we are
    119 // taking from the Switch node was not necessarily the first output, but _Arg
    120 // nodes only have one output. By adding the Switch node to `squash_src_outputs`
    121 // we rewrite the src_output of the corresponding edge to be 0.
    122 Status CopySubgraph(const Graph& graph, const Frame* frame,
    123                     std::vector<Node*> stack,
    124                     const std::vector<bool>& squash_src_outputs,
    125                     std::vector<Node*>* node_map, Graph* output) {
    126   VLOG(3) << "Stack: " << NodesToString(stack);
    127   std::vector<bool> visited(graph.num_node_ids(), false);
    128   while (!stack.empty()) {
    129     Node* n = stack.back();
    130     stack.pop_back();
    131 
    132     VLOG(5) << "Copying node " << n->name();
    133 
    134     if (visited[n->id()]) continue;
    135     visited[n->id()] = true;
    136 
    137     for (const Edge* e : n->in_edges()) {
    138       Node* src = e->src();
    139       if (frame != nullptr && frame->nodes.find(src) == frame->nodes.end()) {
    140         // We traversed out of the loop frame, without encountering a cut node.
    141         return errors::Internal("Graph traversal of loop frame ", frame->name,
    142                                 " escaped frame at ", src->name(),
    143                                 " without encountering an argument node.");
    144       }
    145       if ((*node_map)[src->id()] == nullptr) {
    146         (*node_map)[src->id()] = output->CopyNode(src);
    147         stack.push_back(src);
    148       }
    149       Node* src_copy = (*node_map)[e->src()->id()];
    150       int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge()
    151                            ? 0
    152                            : e->src_output();
    153       Node* dst_copy = (*node_map)[e->dst()->id()];
    154       output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
    155     }
    156   }
    157   return Status::OK();
    158 }
    159 
    160 StatusOr<Node*> AddNode(const NodeDef& node_def, Graph* graph) {
    161   Status status;
    162   Node* inserted_node = graph->AddNode(node_def, &status);
    163   if (!status.ok()) {
    164     return status;
    165   }
    166   return inserted_node;
    167 }
    168 
    169 StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
    170   NodeDef arg_def;
    171   NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp);
    172   builder.Attr("T", type);
    173   builder.Attr("index", index);
    174   TF_RETURN_IF_ERROR(builder.Finalize(&arg_def));
    175   return AddNode(arg_def, graph);
    176 }
    177 
    178 StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index) {
    179   NodeDef ret_def;
    180   ret_def.set_op(kRetValOp);
    181   ret_def.set_name(strings::StrCat(kRetValOp, index));
    182   AddNodeAttr("T", type, &ret_def);
    183   AddNodeAttr("index", index, &ret_def);
    184   return AddNode(ret_def, graph);
    185 }
    186 
    187 // Builds a graph for the loop condition.
    188 Status BuildLoopCondition(const Graph& graph, Frame* frame,
    189                           std::unique_ptr<Graph>* cond_output) {
    190   VLOG(2) << "Building loop condition for " << frame->name;
    191   *cond_output = xla::MakeUnique<Graph>(graph.op_registry());
    192   Graph* output = cond_output->get();
    193 
    194   // Map from nodes in the original graph to the condition graph.
    195   std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
    196   std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
    197 
    198   // Build one _Arg node for each Enter node.
    199   for (int i = 0; i < frame->args.size(); ++i) {
    200     const Arg& arg = frame->args[i];
    201 
    202     TF_ASSIGN_OR_RETURN(Node * arg_node,
    203                         BuildArgNode(output, arg.enter->input_type(0), i));
    204     if (arg.is_loop_invariant) {
    205       node_map[arg.enter->id()] = arg_node;
    206     } else {
    207       node_map[arg.merge->id()] = arg_node;
    208     }
    209   }
    210 
    211   // Build a Retval node for the loop condition. The LoopCond nodes are always
    212   // boolean because of the type constraints on the LoopCond op.
    213   TF_ASSIGN_OR_RETURN(node_map[frame->loop_cond->id()],
    214                       BuildRetvalNode(output, DT_BOOL, 0));
    215 
    216   // Performs a reverse DFS, copying nodes and edges to the output graph.
    217   // The _Arg and _Retval nodes were added unconditionally above, so we are
    218   // guaranteed to get the correct function signature.
    219   return CopySubgraph(graph, frame, {frame->loop_cond}, squash_src_outputs,
    220                       &node_map, output);
    221 }
    222 
    223 // Builds a graph for the loop body.
    224 Status BuildLoopBody(const Graph& graph, Frame* frame,
    225                      DataTypeVector* arg_types,
    226                      std::unique_ptr<Graph>* body_output) {
    227   VLOG(2) << "Building loop body for " << frame->name;
    228   *body_output = xla::MakeUnique<Graph>(graph.op_registry());
    229   Graph* output = body_output->get();
    230 
    231   // Map from nodes in the original graph to the condition graph.
    232   std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
    233   std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
    234 
    235   // Build one _Arg node for each Enter node.
    236   std::vector<Node*> next_iterations;
    237   next_iterations.reserve(frame->args.size());
    238   arg_types->reserve(frame->args.size());
    239   for (int i = 0; i < frame->args.size(); ++i) {
    240     const Arg& arg = frame->args[i];
    241 
    242     DataType dtype = arg.enter->input_type(0);
    243     arg_types->push_back(dtype);
    244 
    245     TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i));
    246 
    247     if (dtype == DT_RESOURCE) {
    248       // The convention of the XLA bridge is that resource variable arguments
    249       // are only inputs to the loop body and have no corresponding output.
    250       // TODO(b/37741920): change the convention so that DT_RESOURCE variables
    251       // are both inputs and outputs, and then remove this case.
    252       TF_RET_CHECK(arg.is_loop_invariant);
    253       node_map[arg.enter->id()] = arg_node;
    254     } else {
    255       TF_ASSIGN_OR_RETURN(Node * retval_node,
    256                           BuildRetvalNode(output, dtype, i));
    257 
    258       if (arg.is_loop_invariant) {
    259         // Argument is loop-invariant. Forward it from the Arg to the Retval.
    260         node_map[arg.enter->id()] = arg_node;
    261         output->AddEdge(arg_node, 0, retval_node, 0);
    262       } else {
    263         // Argument is loop-varying.
    264         node_map[arg.switch_node->id()] = arg_node;
    265         // The Switch node has two outputs, but _Arg only has one. This tells
    266         // the CopySubgraph function to rewrite the output number of edges from
    267         // the _Arg node to be 0 rather than copying the output number from the
    268         // Switch node.
    269         squash_src_outputs[arg.switch_node->id()] = true;
    270         node_map[arg.next_iteration->id()] = retval_node;
    271         next_iterations.push_back(arg.next_iteration);
    272       }
    273     }
    274   }
    275 
    276   // Performs a reverse DFS, copying nodes and edges to the output graph.
    277   // The _Arg and _Retval nodes were added unconditionally above, so we are
    278   // guaranteed to get the correct function signature.
    279   TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations),
    280                                   squash_src_outputs, &node_map, output));
    281 
    282   return Status::OK();
    283 }
    284 
    285 Status FunctionalizeLoop(Graph* graph, Frame* frame,
    286                          FunctionLibraryDefinition* library) {
    287   VLOG(2) << "Frame " << frame->name << " before: "
    288           << dump_graph::DumpGraphToFile("functionalize_before", *graph,
    289                                          library);
    290 
    291   // Split loop-varying Enter nodes with multiple successors. If the same
    292   // Tensor is fed as input to multiple loop arguments, we may end up with a
    293   // shared Enter node. We clone Enter nodes with multiple successors to
    294   // maintain the invariant of a unique Enter node per argument of the final
    295   // loop.
    296   std::vector<Arg> args;
    297   for (const Arg& arg : frame->args) {
    298     if (arg.is_loop_invariant) {
    299       args.push_back(arg);
    300     } else {
    301       std::vector<const Edge*> edges(arg.enter->out_edges().begin(),
    302                                      arg.enter->out_edges().end());
    303       for (int i = 0; i < edges.size(); ++i) {
    304         if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) {
    305           continue;
    306         }
    307         TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name();
    308         Arg new_arg;
    309         new_arg.is_loop_invariant = false;
    310         if (i == 0) {
    311           new_arg.enter = arg.enter;
    312         } else {
    313           new_arg.enter = graph->CopyNode(arg.enter);
    314           frame->nodes.insert(new_arg.enter);
    315           for (Edge const* e : arg.enter->in_edges()) {
    316             graph->AddEdge(e->src(), e->src_output(), new_arg.enter,
    317                            e->IsControlEdge() ? Graph::kControlSlot : 0);
    318           }
    319           Node* dst = edges[i]->dst();
    320           int dst_input = edges[i]->dst_input();
    321           graph->RemoveEdge(edges[i]);
    322           graph->AddEdge(new_arg.enter, 0, dst, dst_input);
    323         }
    324         args.push_back(new_arg);
    325       }
    326     }
    327   }
    328   frame->args = std::move(args);
    329 
    330   std::sort(
    331       frame->args.begin(), frame->args.end(),
    332       [](const Arg& a, const Arg& b) { return NodeCmp()(a.enter, b.enter); });
    333 
    334   if (frame->loop_cond == nullptr) {
    335     return errors::InvalidArgument("Loop ", frame->name,
    336                                    " has no LoopCond node");
    337   }
    338 
    339   // Find the set of Switch nodes that are successors of the LoopCond.
    340   std::unordered_set<Node*> switches;
    341   for (const Edge* edge : frame->loop_cond->out_edges()) {
    342     if (!edge->IsControlEdge() && IsSwitch(edge->dst()) &&
    343         edge->dst_input() == 1) {
    344       switches.insert(edge->dst());
    345     }
    346   }
    347 
    348   // For each non-constant argument, looks for the following pattern of nodes:
    349   // Enter ----> Merge  -------->  Switch  --> Exit
    350   //               ^                  ^
    351   //               |                  |
    352   //         NextIteration         LoopCond
    353   //               ^                  ^
    354   //               |                  |
    355   //              ...                ...
    356   for (Arg& arg : frame->args) {
    357     if (!arg.is_loop_invariant) {
    358       // Follow the edge from the Enter to Merge.
    359       const Edge* enter_merge = nullptr;
    360       for (const Edge* e : arg.enter->out_edges()) {
    361         // Ignore control-edges to the sink node. These are allowed by the
    362         // graph invariants, although probably they should have been stripped
    363         // off earlier.
    364         if (e->IsControlEdge() && e->dst()->IsSink()) {
    365           continue;
    366         }
    367         if (enter_merge != nullptr) {
    368           return errors::Internal(
    369               "Enter node for loop-varying argument ", arg.enter->name(),
    370               " has multiple successors: ", enter_merge->dst()->name(), " and ",
    371               e->dst()->name());
    372         }
    373         enter_merge = e;
    374       }
    375       if (enter_merge == nullptr) {
    376         return errors::Internal("Enter node for loop-varying argument ",
    377                                 arg.enter->name(), " has zero successors");
    378       }
    379       arg.merge = enter_merge->dst();
    380       if (!IsMerge(arg.merge)) {
    381         return errors::InvalidArgument(
    382             "Successor of Enter node for loop-varying argument ",
    383             arg.merge->name(),
    384             " is not a Merge node; got: ", arg.merge->type_string());
    385       }
    386 
    387       // Find the NextIteration from the merge. There should be two inputs to
    388       // the Merge and the NextIteration should be the other input.
    389       if (arg.merge->input_types().size() != 2) {
    390         return errors::InvalidArgument(
    391             "Unexpected number of inputs to Merge node for loop-varying "
    392             "argument ",
    393             arg.merge->name(), "; expected 2, got ",
    394             arg.merge->input_types().size());
    395       }
    396       TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(),
    397                                                &arg.next_iteration));
    398       if (!IsNextIteration(arg.next_iteration)) {
    399         return errors::InvalidArgument(
    400             "Expected NextIteration node as input to Merge node; got node ",
    401             arg.next_iteration->name(), " with kind ",
    402             arg.next_iteration->type_string());
    403       }
    404 
    405       // Find the Switch successor of the Merge. There should be exactly one
    406       // Switch node that is a successor of both the Merge and the LoopCond.
    407       for (const Edge* edge : arg.merge->out_edges()) {
    408         if (edge->dst_input() == 0 && IsSwitch(edge->dst()) &&
    409             switches.find(edge->dst()) != switches.end()) {
    410           if (arg.switch_node != nullptr) {
    411             return errors::InvalidArgument("Duplicate Switch successors to ",
    412                                            arg.merge->name());
    413           }
    414           arg.switch_node = edge->dst();
    415         }
    416       }
    417       if (arg.switch_node == nullptr) {
    418         return errors::InvalidArgument("Missing Switch successor to ",
    419                                        arg.merge->name());
    420       }
    421 
    422       // Update the device on the Identity outputs of the switch to match their
    423       // target. These Identity outputs do not
    424 
    425       // Loop over the switch node's output to:
    426       // - Find the Exit successor.
    427       // - Set the sharding on all Identity outputs of the switch. These
    428       //   identity nodes are values used by the loop body or condition.
    429       //   The Identity node may have the wrong device so copy the device from
    430       //   one of its outputs instead.
    431       std::deque<const Edge*> possible_exit;
    432       for (const Edge* edge : arg.switch_node->out_edges()) {
    433         if (edge->src_output() == 0) {
    434           possible_exit.push_back(edge);
    435         }
    436         if (IsIdentity(edge->dst())) {
    437           TF_RETURN_IF_ERROR(
    438               SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true));
    439         }
    440       }
    441       // TODO(b/67425339): Allow general graph between switch and exit.
    442       while (!possible_exit.empty()) {
    443         const Edge* edge = possible_exit.front();
    444         possible_exit.pop_front();
    445         if (IsExit(edge->dst())) {
    446           if (arg.exit != nullptr) {
    447             return errors::InvalidArgument("Duplicate Exit successors to ",
    448                                            arg.switch_node->name());
    449           }
    450           arg.exit = edge->dst();
    451         } else {
    452           if (!IsIdentity(edge->dst())) {
    453             return errors::Unimplemented("General graph between switch (",
    454                                          arg.switch_node->name(),
    455                                          ") and exit node of frame ",
    456                                          frame->name, " not supported yet.");
    457           }
    458           for (const Edge* out : edge->dst()->out_edges()) {
    459             possible_exit.push_back(out);
    460           }
    461         }
    462       }
    463     }
    464   }
    465 
    466   // Builds the condition and body functions.
    467   std::unique_ptr<Graph> cond_graph;
    468   TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
    469   DataTypeVector arg_types;
    470   std::unique_ptr<Graph> body_graph;
    471   TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
    472 
    473   VLOG(2) << "Frame " << frame->name << " condition: "
    474           << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library)
    475           << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph);
    476 
    477   static std::atomic<int64> sequence_num(0LL);
    478   int64 id = ++sequence_num;
    479   NameAttrList cond_name;
    480   cond_name.set_name(strings::StrCat("_functionalize_cond_", id));
    481   NameAttrList body_name;
    482   body_name.set_name(strings::StrCat("_functionalize_body_", id));
    483   FunctionDef cond_fdef;
    484   TF_RETURN_IF_ERROR(
    485       GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef));
    486   FunctionDef body_fdef;
    487   TF_RETURN_IF_ERROR(
    488       GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef));
    489 
    490   TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
    491   TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
    492 
    493   // Builds a While operator.
    494   NodeDef while_def;
    495   NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile");
    496   builder.Attr("T", arg_types);
    497   builder.Attr("cond", cond_name);
    498   builder.Attr("body", body_name);
    499   std::vector<NodeDefBuilder::NodeOut> inputs;
    500   for (int i = 0; i < frame->args.size(); ++i) {
    501     const Arg& arg = frame->args[i];
    502     const Edge* in_edge;
    503     TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
    504     if (in_edge->IsControlEdge()) {
    505       builder.ControlInput(in_edge->src()->name());
    506     } else {
    507       inputs.push_back(NodeDefBuilder::NodeOut(
    508           in_edge->src()->name(), in_edge->src_output(), arg_types[i]));
    509     }
    510   }
    511   builder.Input(inputs);
    512   TF_RETURN_IF_ERROR(builder.Finalize(&while_def));
    513   TF_ASSIGN_OR_RETURN(Node * while_node, AddNode(while_def, graph));
    514 
    515   // Copies edges to the Enter nodes and from the Exit nodes onto the While.
    516   for (int i = 0; i < frame->args.size(); ++i) {
    517     const Arg& arg = frame->args[i];
    518     const Edge* in_edge;
    519     TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
    520     if (in_edge->IsControlEdge()) {
    521       graph->AddControlEdge(in_edge->src(), while_node);
    522     } else {
    523       graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i);
    524     }
    525 
    526     if (!arg.is_loop_invariant) {
    527       // Add output edges if the output of the loop is consumed.
    528       if (arg.exit != nullptr) {
    529         std::vector<const Edge*> edges(arg.exit->out_edges().begin(),
    530                                        arg.exit->out_edges().end());
    531         for (const Edge* edge : edges) {
    532           Node* dst = edge->dst();
    533           int dst_input = edge->dst_input();
    534           graph->RemoveEdge(edge);
    535 
    536           if (dst_input == Graph::kControlSlot) {
    537             graph->AddControlEdge(while_node, dst);
    538           } else {
    539             graph->AddEdge(while_node, i, dst, dst_input);
    540           }
    541         }
    542       }
    543     }
    544   }
    545 
    546   // Remove the old nodes from the graph, and add the while node to the parent
    547   // frame.
    548   for (Node* node : frame->nodes) {
    549     graph->RemoveNode(node);
    550   }
    551   frame->nodes.clear();
    552   frame->parent->nodes.insert(while_node);
    553 
    554   VLOG(2) << "Frame " << frame->name << " after: "
    555           << dump_graph::DumpGraphToFile("functionalize_after", *graph,
    556                                          library);
    557 
    558   return Status::OK();
    559 }
    560 
    561 class FunctionalizeCond {
    562  public:
    563   // All nodes are assumed to be either in no branch, then branch, else branch,
    564   // or both branches (such as merge nodes).
    565   enum Branch {
    566     kElseBranch = 0,
    567     kThenBranch = 1,
    568     kBoth = 2,
    569     kNeither = 3,
    570     kNumBranchTypes = 4
    571   };
    572 
    573   // Returns a textual representation of the Branch b.
    574   static string Branch_Name(FunctionalizeCond::Branch b);
    575 
    576   // Functionalize all the switch-merge nodes of a loop-free graph into XlaIf
    577   // nodes. That is, attempt to transform every remaining switch and merge nodes
    578   // in the graph into XlaIf nodes.
    579   // Precondition: All while loops have been removed from graph.
    580   static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library);
    581 
    582  private:
    583   // CondArgNode represents a input to the conditional and its corresponding
    584   // switch nodes.
    585   struct CondArgNode {
    586     explicit CondArgNode(Node* input) : input(input) {}
    587     string ToString() const {
    588       return strings::StrCat("input=", input->name(),
    589                              " switches=", NodesToString(switches));
    590     }
    591 
    592     Node* input;
    593     std::vector<Node*> switches;
    594   };
    595   using CondArgNodes = std::vector<CondArgNode>;
    596 
    597   struct ForwardFlowNode {
    598     explicit ForwardFlowNode(Branch branch = Branch::kNeither)
    599         : branch(branch), count(0) {}
    600     string ToString() const {
    601       return strings::StrCat("branch=", Branch_Name(branch), " count=", count);
    602     }
    603     Branch branch;
    604     int count;
    605   };
    606 
    607   // Group of switch nodes that will be part of the same XlaIf.
    608   struct SwitchCluster {
    609     explicit SwitchCluster(Node* predicate) : predicate(predicate) {}
    610     string ToString() const {
    611       return strings::StrCat(name, " predicate=", predicate->name(),
    612                              " switches=", NodesToString(switches));
    613     }
    614 
    615     string name;
    616     Node* predicate;
    617     std::vector<Node*> switches;
    618   };
    619 
    620   FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library,
    621                     bool dump_graphs)
    622       : library_(library), graph_(graph), dump_graphs_(dump_graphs) {}
    623 
    624   // Perform the actual cond functionalization. Iterate over groups of switch
    625   // nodes (linked by common predicate), from innermost to outermost, and
    626   // extract into XlaIf nodes.
    627   Status FunctionalizeInternal();
    628 
    629   // Determines the branch_map (mapping from node to branch of cond) and
    630   // frontier (the nodes where the cond ends).
    631   StatusOr<std::pair<std::unordered_map<Node*, ForwardFlowNode>,
    632                      std::unordered_set<Node*>>>
    633   DetermineBranchMapAndFrontier(const SwitchCluster& switch_cluster);
    634 
    635   // Returns XlaIf node created from subgraph of merge and switch nodes. This
    636   // encapsulates the process of extracting the bodies needed for the then and
    637   // else branch, creates a XlaIf node, removing the nodes of the branches from
    638   // the graph and replacing the merge node with a XlaIf.
    639   StatusOr<Node*> ConvertToXlaIf(const CondArgNodes& cond_arg_nodes,
    640                                  const SwitchCluster& switch_cluster,
    641                                  const std::vector<Node*>& switches);
    642 
    643   // Builds a XlaIfOp to replace the Switch-Graph-Merge cluster with.
    644   StatusOr<Node*> BuildAndAddXlaIfOp(const CondArgNodes& cond_arg_nodes,
    645                                      const SwitchCluster& switch_cluster,
    646                                      const std::vector<Node*>& merge_nodes);
    647 
    648   // Extracts a function body corresponding to the given input edge of the merge
    649   // node.
    650   Status ExtractBody(const CondArgNodes& cond_arg_nodes,
    651                      const std::vector<Node*>& switches,
    652                      const std::vector<Node*>& merge_nodes, int input_edge,
    653                      Graph* body);
    654 
    655   // Adds all the input edges to `if_node` corresponding to the arguments.
    656   Status AddInputEdges(const CondArgNodes& cond_arg_nodes, Node* predicate,
    657                        Node* if_node);
    658 
    659   // Adds all output edges from the `if_node`.
    660   Status AddOutputEdges(const std::vector<Node*>& outputs, Node* if_node);
    661 
    662   // Returns the switch clusters of graph_ in postorder. Dead switch nodes are
    663   // skipped and removed from the graph.
    664   StatusOr<std::vector<SwitchCluster>> DeterminePredicateSwitchOrder();
    665 
    666   // Update the state for destination based on the state of source and the node
    667   // being updated.
    668   Status Join(const ForwardFlowNode& src_state, const Node* dst,
    669               ForwardFlowNode* dst_state);
    670 
    671   // Ensure that all nodes in the branch_map are dominated by the switch
    672   // nodes. Returns nodes that are not dominated by the switches but are a
    673   // control dependency of a node in the cond, and remove such control
    674   // dependencies.
    675   StatusOr<std::vector<Node*>> EnsureDominanceAndReturnNonDominatedControlNodes(
    676       const std::unordered_map<Node*, ForwardFlowNode>& branch_map,
    677       const std::vector<Node*>& switches);
    678 
    679   // Validates that the frontier of nodes for the conditional
    680   // section are as expected.
    681   Status ValidateFrontier(
    682       const std::unordered_map<Node*, ForwardFlowNode>& branch_map,
    683       const std::unordered_set<Node*>& frontier);
    684 
    685   FunctionLibraryDefinition* library_;
    686   Graph* graph_;
    687   bool dump_graphs_;
    688 };
    689 
    690 bool IsDeadSwitch(const Node* node) {
    691   for (const Edge* e : node->out_edges()) {
    692     const Node* dst = e->dst();
    693     if (!dst->IsIdentity()) {
    694       return false;
    695     }
    696     for (const Edge* ee : dst->out_edges()) {
    697       if (!ee->IsControlEdge() || !ee->dst()->IsSink()) {
    698         return false;
    699       }
    700     }
    701   }
    702   return true;
    703 }
    704 
    705 string FunctionalizeCond::Branch_Name(FunctionalizeCond::Branch b) {
    706   const string branch_name[FunctionalizeCond::kNumBranchTypes + 1] = {
    707       "else", "then", "both", "neither", "count"};
    708   return branch_name[b];
    709 }
    710 
    711 Status FunctionalizeCond::ValidateFrontier(
    712     const std::unordered_map<Node*, FunctionalizeCond::ForwardFlowNode>&
    713         branch_map,
    714     const std::unordered_set<Node*>& frontier) {
    715   std::unordered_set<const Node*> pending[kNumBranchTypes];
    716   for (Node* n : frontier) {
    717     pending[branch_map.at(n).branch].insert(n);
    718   }
    719   TF_RET_CHECK(pending[kNeither].empty()) << NodesToString(pending[kNeither]);
    720   for (const Node* n : pending[kBoth]) {
    721     TF_RET_CHECK(IsMerge(n)) << n->DebugString();
    722     // Merge nodes may be in then or else branch too
    723   }
    724   int index = (pending[kThenBranch].size() <= pending[kElseBranch].size())
    725                   ? kThenBranch
    726                   : kElseBranch;
    727   int other = 1 - index;
    728   for (const Node* n : pending[index]) {
    729     if (pending[other].find(n) != pending[other].end()) {
    730       return errors::Internal(
    731           "Node (", n->DebugString().c_str(),
    732           ") in both Else and Then branch should be in Both.");
    733     }
    734   }
    735   // An empty frontier indicates a dead switch. Above we attempt to remove dead
    736   // switch nodes, but not all are removed so don't treat it as an error yet.
    737   // TODO(jpienaar): Find out why dead switch nodes remain.
    738   // if (pending[kBoth].empty() && pending[kThenBranch].empty() &&
    739   //     pending[kElseBranch].empty()) {
    740   //   return errors::Internal("Unexpected empty frontier for switch nodes");
    741   // }
    742   return Status::OK();
    743 }
    744 
    745 Status FunctionalizeCond::Join(const ForwardFlowNode& src_state,
    746                                const Node* dst, ForwardFlowNode* dst_state) {
    747   TF_RET_CHECK(dst_state->branch != Branch::kBoth &&
    748                dst_state->branch != Branch::kNumBranchTypes)
    749       << "Unexpected/Invalid branch type: Merging "
    750       << Branch_Name(src_state.branch) << " with "
    751       << Branch_Name(dst_state->branch);
    752   if (dst_state->branch == Branch::kNeither) {
    753     dst_state->branch = src_state.branch;
    754   } else if (src_state.branch != dst_state->branch &&
    755              src_state.branch != Branch::kNeither) {
    756     if (IsMerge(dst)) {
    757       dst_state->branch = Branch::kBoth;
    758     } else {
    759       return errors::Internal("Illegal merge: ", src_state.ToString(), " with ",
    760                               dst_state->ToString(), " for ",
    761                               dst->DebugString());
    762     }
    763   }
    764   ++dst_state->count;
    765   return Status::OK();
    766 }
    767 
    768 StatusOr<std::vector<FunctionalizeCond::SwitchCluster>>
    769 FunctionalizeCond::DeterminePredicateSwitchOrder() {
    770   struct Cluster {
    771     bool operator==(const Cluster& other) const {
    772       return representative == other.representative;
    773     }
    774     int representative = -1;
    775   };
    776 
    777   // Perform a DFS over the graph and
    778   // * Determine the reverse topological order of the nodes (there should be no
    779   //   cycles at this point so the post-order numbering corresponds to the
    780   //   reverse topological sorting);
    781   // * Identify dead switches;
    782   // * Initialize the cluster's representative;
    783   std::vector<UnionFind<Cluster>> clusters(graph_->num_node_ids());
    784   std::vector<Node*> dead_switches;
    785   std::vector<Node*> switch_order;
    786   std::vector<Node*> rev_topo_sorted_nodes;
    787   DFS(*graph_, nullptr, [&](Node* n) {
    788     clusters[n->id()].Get().representative = n->id();
    789     if (IsSwitch(n)) {
    790       if (IsDeadSwitch(n)) {
    791         dead_switches.push_back(n);
    792       } else {
    793         rev_topo_sorted_nodes.push_back(n);
    794         switch_order.push_back(n);
    795       }
    796     } else if (n->IsOp()) {
    797       // Exclude src and sink nodes from further consideration.
    798       rev_topo_sorted_nodes.push_back(n);
    799     }
    800   });
    801 
    802   std::vector<SwitchCluster> switch_clusters;
    803   // Return early if there are no switches in the graph.
    804   if (switch_order.empty()) {
    805     return switch_clusters;
    806   }
    807 
    808   // Remove all dead switch nodes.
    809   for (Node* n : dead_switches) {
    810     VLOG(2) << "Removing dead switch: " << n->DebugString();
    811     graph_->RemoveNode(n);
    812   }
    813 
    814   // Identify switch nodes that are part of the same control flow context by
    815   // considering the operands of operations: an operation is part of the same
    816   // control context as its operands unless the operation is a switch. Control
    817   // dependencies are considered part of the same control flow context if the
    818   // switch depth is the same (see comment below).
    819 
    820   // entry_cluster records the input cluster to a switch node. This is used when
    821   // merging with a merge node where the dst's cluster is merged with the entry
    822   // cluster of the merge node's cluster (which corresponds to a switch cluster
    823   // and so has an entry cluster).
    824   std::unordered_map<int, UnionFind<Cluster>*> entry_cluster;
    825 
    826   // Returns the output cluster of a node. Where the output cluster is cluster
    827   // where the output of the node is used. For non-merge nodes this is simply
    828   // the cluster they are part of, while for merge nodes it is the entry cluster
    829   // of the cluster they are part of (this will correspond to the entry node of
    830   // a switch node that dominates the merge).
    831   auto find_output_cluster = [&](Node* n) {
    832     UnionFind<Cluster>* cluster = &clusters[n->id()];
    833     if (!IsMerge(n)) return cluster;
    834     auto it = entry_cluster.find(clusters[n->id()].Get().representative);
    835     // If the cluster is not found in the entry_cluster map then an
    836     // instruction not dominated by a switch node has been merged into the
    837     // cluster of the merge. This indicates a failure of the clustering.
    838     CHECK(it != entry_cluster.end())
    839         << "Unable to find entry for n=" << n->id() << " ("
    840         << cluster->Get().representative << ")";
    841     return it->second;
    842   };
    843 
    844   // TODO(jpienaar): This could be combined with DetermineBranchMapAndFrontier.
    845   std::vector<int> switch_depth(graph_->num_node_ids());
    846   for (auto it = rev_topo_sorted_nodes.rbegin();
    847        it != rev_topo_sorted_nodes.rend(); ++it) {
    848     Node* n = *it;
    849 
    850     // Compute switch depth.
    851     int new_switch_depth = 0;
    852     for (const Edge* e : n->in_edges()) {
    853       Node* src = e->src();
    854       new_switch_depth = std::max(
    855           new_switch_depth, switch_depth[src->id()] - (IsMerge(src) ? 1 : 0));
    856     }
    857     switch_depth[n->id()] = new_switch_depth + (IsSwitch(n) ? 1 : 0);
    858 
    859     // Only merge the input operands of a switch. The switch's clustering itself
    860     // is determined by the interaction of the switch's outputs.
    861     if (IsSwitch(n)) {
    862       Node* input;
    863       TF_CHECK_OK(n->input_node(0, &input));
    864       entry_cluster[n->id()] = &clusters[input->id()];
    865       UnionFind<Cluster>* cluster = find_output_cluster(input);
    866       int cluster_depth = switch_depth[cluster->Get().representative];
    867       // Merge the inputs of the switch node with one another. This results in
    868       // predicates and control input residing in the same cluster.
    869       for (const Edge* e : n->in_edges()) {
    870         Node* src = e->src();
    871         UnionFind<Cluster>* src_cluster = find_output_cluster(src);
    872         int src_cluster_depth = switch_depth[src_cluster->Get().representative];
    873         if (cluster_depth != src_cluster_depth) {
    874           return errors::InvalidArgument(
    875               "Unable to functionalize control flow in graph: Switch ('",
    876               n->name(), "') has operands ('", input->name(), "' and '",
    877               src->name(), "') that have different switch depths (",
    878               cluster_depth, " != ", src_cluster_depth, ")");
    879         }
    880         cluster->Merge(src_cluster);
    881       }
    882       continue;
    883     }
    884 
    885     for (const Edge* e : n->in_edges()) {
    886       Node* src = e->src();
    887       if (!src->IsOp()) continue;
    888       UnionFind<Cluster>* cluster = find_output_cluster(src);
    889       // Merge a node with its data operands and with its control operands if
    890       // the src and dst are in the same ControlContext. The ControlContext is
    891       // not explicitly available here, and instead the switch depth is used as
    892       // a proxy here. Due to the invariant that control edges can only be from
    893       // a containing scope to an inner scope or from the inner scope to its
    894       // containing scope (for exit nodes), the switch depth will only match if
    895       // the src and dst are in the same ControlContext. Control edges between
    896       // ControlContexts are handled during the extraction.
    897       int src_id = cluster->Get().representative;
    898       int src_depth = switch_depth[src_id];
    899       if (!e->IsControlEdge() || new_switch_depth == src_depth) {
    900         if (src_depth != new_switch_depth) {
    901           return errors::InvalidArgument(
    902               "Unable to functionalize control flow in graph: Operand ('",
    903               src->name(), "') and operator ('", n->name(),
    904               "') have different switch depths (", src_depth,
    905               " != ", new_switch_depth, ")");
    906         }
    907         cluster->Merge(&clusters[n->id()]);
    908       }
    909     }
    910   }
    911 
    912   if (dump_graphs_) {
    913     // Mark the switch cluster each node is part of.
    914     for (Node* n : graph_->nodes()) {
    915       n->ClearAttr("_XlaFunctionalizeSwitchGroup");
    916       n->AddAttr("_XlaFunctionalizeSwitchGroup",
    917                  clusters[n->id()].Get().representative);
    918     }
    919     LOG(INFO) << "FunctionalizeControlFlow (with_clusters): "
    920               << dump_graph::DumpGraphToFile("functionalize_clustered", *graph_,
    921                                              library_);
    922   }
    923 
    924   // Verify all the nodes of a cluster are at the same depth.
    925   std::unordered_map<int, std::pair<int, Node*>> cluster_to_depth_node;
    926   for (Node* n : graph_->nodes()) {
    927     int depth = switch_depth[n->id()];
    928     int cluster_rep = clusters[n->id()].Get().representative;
    929     auto it = cluster_to_depth_node.find(cluster_rep);
    930     if (it == cluster_to_depth_node.end()) {
    931       cluster_to_depth_node[cluster_rep] = std::make_pair(depth, n);
    932     } else {
    933       if (it->second.first != depth) {
    934         return errors::Internal(
    935             "Illegal clustering created, mismatch in depths:", "\n\t",
    936             n->DebugString(), "(", clusters[n->id()].Get().representative,
    937             ") at depth=", depth, " vs\n\t", it->second.second->DebugString(),
    938             "(", clusters[n->id()].Get().representative, ") at depth ",
    939             it->second.first);
    940       }
    941     }
    942   }
    943 
    944   struct Hash {
    945     size_t operator()(const std::pair<Node*, Cluster>& item) const {
    946       return Hash64Combine(hash<Node*>()(item.first),
    947                            std::hash<int>()(item.second.representative));
    948     }
    949   };
    950 
    951   // Merge Switch nodes with common predicate.
    952   std::unordered_map<std::pair<Node*, Cluster>, int, Hash> predicate_index;
    953   // The nodes in switch_order are in reverse topological order, but the
    954   // clustered switches need not be (i.e., when considered as a cluster one
    955   // element of a cluster may be later in the topological order than another
    956   // node whose cluster is later in the topological order of clustered
    957   // switches).
    958   for (auto it = switch_order.rbegin(); it != switch_order.rend(); ++it) {
    959     Node* pred;
    960     TF_CHECK_OK((*it)->input_node(1, &pred));
    961     auto repr = std::make_pair(pred, clusters[(*it)->id()].Get());
    962     if (predicate_index.find(repr) == predicate_index.end()) {
    963       predicate_index[repr] = switch_clusters.size();
    964       switch_clusters.emplace_back(pred);
    965       // Generate a name by concatenating with the cluster representative as
    966       // there could be multiple switch clusters with the same predicate.
    967       switch_clusters[predicate_index[repr]].name =
    968           strings::StrCat(pred->name(), "_", repr.second.representative, "_If");
    969     }
    970     switch_clusters[predicate_index[repr]].switches.push_back(*it);
    971   }
    972 
    973   return switch_clusters;
    974 }
    975 
    976 StatusOr<std::vector<Node*>>
    977 FunctionalizeCond::EnsureDominanceAndReturnNonDominatedControlNodes(
    978     const std::unordered_map<Node*, ForwardFlowNode>& branch_map,
    979     const std::vector<Node*>& switches) {
    980   std::vector<Node*> old_control_nodes;
    981   for (const auto& kv : branch_map) {
    982     if (kv.second.count != kv.first->in_edges().size()) {
    983       std::vector<const Edge*> delete_edges;
    984       for (const Edge* in : kv.first->in_edges()) {
    985         auto it = branch_map.find(in->src());
    986         if (it == branch_map.end()) {
    987           if (in->IsControlEdge()) {
    988             old_control_nodes.push_back(in->src());
    989             delete_edges.push_back(in);
    990           } else {
    991             if (IsSwitch(in->src())) {
    992               if (std::find(switches.begin(), switches.end(), in->src()) ==
    993                   switches.end()) {
    994                 return errors::Internal(
    995                     "Unexpected switch node found during flow forward: ",
    996                     in->src()->DebugString());
    997               }
    998               continue;
    999             }
   1000             return errors::InvalidArgument(
   1001                 "Value ", kv.first->name(), "'s input, ", in->src()->name(),
   1002                 ", is not dominated by switch nodes ", NodesToString(switches));
   1003           }
   1004         }
   1005       }
   1006       // Remove control edges from nodes that are not dominated by the switch
   1007       // nodes. New control dependencies will be added between these nodes and
   1008       // the XlaIf node inserted.
   1009       for (const Edge* e : delete_edges) {
   1010         graph_->RemoveEdge(e);
   1011       }
   1012     }
   1013   }
   1014   return old_control_nodes;
   1015 }
   1016 
   1017 StatusOr<
   1018     std::pair<std::unordered_map<Node*, FunctionalizeCond::ForwardFlowNode>,
   1019               std::unordered_set<Node*>>>
   1020 FunctionalizeCond::DetermineBranchMapAndFrontier(
   1021     const SwitchCluster& switch_cluster) {
   1022   std::unordered_map<Node*, ForwardFlowNode> branch_map;
   1023   std::unordered_set<Node*> frontier;
   1024   std::vector<Node*> stack = switch_cluster.switches;
   1025   std::vector<bool> visited(graph_->num_node_ids(), false);
   1026   while (!stack.empty()) {
   1027     Node* n = stack.back();
   1028     stack.pop_back();
   1029 
   1030     if (visited[n->id()]) {
   1031       continue;
   1032     }
   1033     visited[n->id()] = true;
   1034 
   1035     // Propagate branch state along each edge of a switch node.
   1036     bool sink_only = true;
   1037     for (const Edge* e : n->out_edges()) {
   1038       Node* out = e->dst();
   1039       if (!out->IsOp()) {
   1040         continue;
   1041       }
   1042       sink_only = false;
   1043       // Propagate branch information.
   1044       ForwardFlowNode& ffn = branch_map[out];
   1045       if (IsSwitch(n)) {
   1046         int index = e->IsControlEdge() ? Branch::kNeither : e->src_output();
   1047         TF_RETURN_IF_ERROR(Join(ForwardFlowNode(Branch(index)), out, &ffn));
   1048       } else {
   1049         TF_RETURN_IF_ERROR(Join(branch_map[n], out, &ffn));
   1050       }
   1051       if (IsMerge(out)) {
   1052         if (out->in_edges().size() == ffn.count) {
   1053           frontier.insert(out);
   1054         }
   1055       } else if (!visited[out->id()]) {
   1056         stack.push_back(out);
   1057       }
   1058     }
   1059     if (sink_only) {
   1060       if (!IsIdentity(n)) {
   1061         VLOG(1) << "Feeding into sink: " << n->DebugString();
   1062       }
   1063     }
   1064   }
   1065 
   1066   if (dump_graphs_) {
   1067     for (const auto& kv : branch_map) {
   1068       // Append attribute to the graph if running with logging to make the
   1069       // changes clearer in the visualization.
   1070       kv.first->AddAttr("_XlaFunctionalizeBranch",
   1071                         Branch_Name(kv.second.branch));
   1072     }
   1073   }
   1074   return std::make_pair(std::move(branch_map), std::move(frontier));
   1075 }
   1076 
   1077 Status FunctionalizeCond::FunctionalizeInternal() {
   1078   TF_ASSIGN_OR_RETURN(std::vector<SwitchCluster> predicate_switch_order,
   1079                       DeterminePredicateSwitchOrder());
   1080 
   1081   // Iterate from innermost set of clustered switches to outermost, replacing
   1082   // matching switch->merge subgraphs with single XlaIf nodes.
   1083   for (auto it = predicate_switch_order.rbegin();
   1084        it != predicate_switch_order.rend(); ++it) {
   1085     auto& ps = *it;
   1086     VLOG(3) << "Flow down from: " << NodesToString(ps.switches) << " ("
   1087             << ps.predicate->name() << ")";
   1088 
   1089     std::unordered_map<Node*, ForwardFlowNode> branch_map;
   1090     std::unordered_set<Node*> frontier;
   1091     TF_ASSIGN_OR_RETURN(std::tie(branch_map, frontier),
   1092                         DetermineBranchMapAndFrontier(ps));
   1093 
   1094     if (dump_graphs_)
   1095       LOG(INFO) << "FunctionalizeControlFlow (before XlaIf conversion): "
   1096                 << dump_graph::DumpGraphToFile("functionalize_bc", *graph_,
   1097                                                library_);
   1098     TF_RETURN_IF_ERROR(ValidateFrontier(branch_map, frontier));
   1099 
   1100     // Sort the merge and switch nodes using NodeCmp. The switch-nodes are
   1101     // further grouped (post sorting) by input to the switch node as in the
   1102     // functionalized form each input will be passed in only once. This grouping
   1103     // should retain the sorted order.
   1104     CondArgNodes cond_arg_nodes;
   1105     std::unordered_map<Node*, int> input_index;
   1106     std::sort(ps.switches.begin(), ps.switches.end(), NodeCmp());
   1107     for (Node* switch_node : ps.switches) {
   1108       Node* in;
   1109       TF_RETURN_IF_ERROR(switch_node->input_node(0, &in));
   1110       if (input_index.find(in) == input_index.end()) {
   1111         input_index[in] = cond_arg_nodes.size();
   1112         cond_arg_nodes.emplace_back(in);
   1113       }
   1114       cond_arg_nodes.at(input_index.at(in)).switches.push_back(switch_node);
   1115     }
   1116     std::vector<Node*> merge_nodes(frontier.begin(), frontier.end());
   1117     std::sort(merge_nodes.begin(), merge_nodes.end(), NodeCmp());
   1118 
   1119     TF_ASSIGN_OR_RETURN(std::vector<Node*> old_control_nodes,
   1120                         EnsureDominanceAndReturnNonDominatedControlNodes(
   1121                             branch_map, ps.switches));
   1122 
   1123     TF_ASSIGN_OR_RETURN(Node * if_node,
   1124                         ConvertToXlaIf(cond_arg_nodes, ps, merge_nodes));
   1125     for (Node* old : old_control_nodes) {
   1126       graph_->AddControlEdge(old, if_node);
   1127     }
   1128 
   1129     for (auto& del_kv : branch_map) {
   1130       graph_->RemoveNode(del_kv.first);
   1131     }
   1132     for (auto& kv : cond_arg_nodes) {
   1133       for (Node* node : kv.switches) {
   1134         graph_->RemoveNode(node);
   1135       }
   1136     }
   1137     if (dump_graphs_)
   1138       LOG(INFO) << "FunctionalizeControlFlow (after XlaIf conversion): "
   1139                 << dump_graph::DumpGraphToFile("functionalize_ac", *graph_,
   1140                                                library_);
   1141   }
   1142   return Status::OK();
   1143 }
   1144 
   1145 StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
   1146     const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster,
   1147     const std::vector<Node*>& merge_nodes) {
   1148   VLOG(2) << "Build if op for " << switch_cluster.name;
   1149 
   1150   NodeDef if_def;
   1151   // Create a new If node using the name of the merge node.
   1152   NodeDefBuilder builder(switch_cluster.name, "XlaIf");
   1153   string branch[] = {"else_branch", "then_branch"};
   1154   for (int i = 0; i < 2; ++i) {
   1155     static std::atomic<int64> sequence_num(0LL);
   1156     int64 id = ++sequence_num;
   1157 
   1158     NameAttrList body_name;
   1159     body_name.set_name(
   1160         strings::StrCat("_functionalize_if_", branch[i], "_", id));
   1161     auto body = xla::MakeUnique<Graph>(graph_->op_registry());
   1162     TF_RETURN_IF_ERROR(ExtractBody(cond_arg_nodes, switch_cluster.switches,
   1163                                    merge_nodes, i, body.get()));
   1164     VLOG(3) << "Body " << branch[i] << ": " << DebugString(body.get());
   1165     FunctionDef body_fdef;
   1166     TF_RETURN_IF_ERROR(GraphToFunctionDef(*body, body_name.name(), &body_fdef));
   1167     TF_RETURN_IF_ERROR(library_->AddFunctionDef(body_fdef));
   1168     builder.Attr(branch[i], body_name);
   1169   }
   1170 
   1171   // Build input type.
   1172   std::vector<NodeDefBuilder::NodeOut> inputs;
   1173   DataTypeVector in_arg_types;
   1174   for (auto& kv : cond_arg_nodes) {
   1175     bool inserted = false;
   1176     for (const Node* arg : kv.switches) {
   1177       const Edge* in_edge;
   1178       TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
   1179       if (in_edge->IsControlEdge()) {
   1180         builder.ControlInput(in_edge->src()->name());
   1181       } else {
   1182         if (!inserted) {
   1183           DataType dtype = arg->input_type(0);
   1184           inputs.emplace_back(NodeDefBuilder::NodeOut(
   1185               in_edge->src()->name(), in_edge->src_output(), dtype));
   1186           in_arg_types.push_back(dtype);
   1187           inserted = true;
   1188         }
   1189       }
   1190     }
   1191   }
   1192   builder.Attr("Tin", in_arg_types);
   1193 
   1194   // Build output type.
   1195   DataTypeVector out_type;
   1196   for (const Node* merge : merge_nodes) {
   1197     DataType dtype = merge->output_type(0);
   1198     out_type.push_back(dtype);
   1199   }
   1200   builder.Attr("Tout", out_type);
   1201 
   1202   builder.Attr("Tcond", DT_BOOL);
   1203   builder.Device(switch_cluster.predicate->assigned_device_name());
   1204   // Conditional should be the first input ...
   1205   builder.Input(
   1206       NodeDefBuilder::NodeOut(switch_cluster.predicate->name(), 0,
   1207                               switch_cluster.predicate->output_type(0)));
   1208   // ... followed by the other inputs.
   1209   builder.Input(inputs);
   1210 
   1211   TF_RETURN_IF_ERROR(builder.Finalize(&if_def));
   1212   TF_ASSIGN_OR_RETURN(Node * if_node, AddNode(if_def, graph_));
   1213   return if_node;
   1214 }
   1215 
   1216 Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes,
   1217                                       const std::vector<Node*>& switches,
   1218                                       const std::vector<Node*>& merge_nodes,
   1219                                       int input_edge, Graph* body) {
   1220   VLOG(2) << "ExtractBody for " << NodesToString(merge_nodes) << " along edge "
   1221           << input_edge;
   1222   std::vector<bool> squash_src_outputs(graph_->num_node_ids(), false);
   1223   std::vector<Node*> node_map(graph_->num_node_ids(), nullptr);
   1224   int arg_count = 0;
   1225   for (auto& kv : cond_arg_nodes) {
   1226     Node* arg_node = nullptr;
   1227     for (const auto* arg : kv.switches) {
   1228       DataType dtype = arg->input_type(0);
   1229       if (arg_node == nullptr) {
   1230         TF_ASSIGN_OR_RETURN(arg_node, BuildArgNode(body, dtype, arg_count++));
   1231       }
   1232       node_map.at(arg->id()) = arg_node;
   1233       squash_src_outputs.at(arg->id()) = true;
   1234     }
   1235   }
   1236 
   1237   std::vector<Node*> stack;
   1238   stack.reserve(merge_nodes.size());
   1239   for (int j = 0; j < merge_nodes.size(); ++j) {
   1240     Node* node = merge_nodes[j];
   1241     TF_ASSIGN_OR_RETURN(node_map.at(node->id()),
   1242                         BuildRetvalNode(body, node->output_type(0),
   1243                                         /*index=*/j));
   1244     const Edge* in_edge;
   1245     TF_RETURN_IF_ERROR(node->input_edge(input_edge, &in_edge));
   1246     Node* in = in_edge->src();
   1247     if (node_map.at(in->id()) == nullptr) {
   1248       node_map.at(in->id()) = body->CopyNode(in);
   1249     }
   1250 
   1251     if (std::find(switches.begin(), switches.end(), in) == switches.end()) {
   1252       body->AddEdge(node_map.at(in->id()), in_edge->src_output(),
   1253                     node_map.at(node->id()), 0);
   1254     } else {
   1255       body->AddEdge(node_map.at(in->id()), 0, node_map.at(node->id()), 0);
   1256       // Don't include input nodes that are already just returned in stack.
   1257       continue;
   1258     }
   1259     stack.push_back(in);
   1260   }
   1261 
   1262   return CopySubgraph(*graph_, nullptr, stack, squash_src_outputs, &node_map,
   1263                       body);
   1264 }
   1265 
   1266 Status FunctionalizeCond::AddInputEdges(const CondArgNodes& cond_arg_nodes,
   1267                                         Node* predicate, Node* if_node) {
   1268   VLOG(3) << "AddInputEdges for " << if_node->name();
   1269   int index = 0;
   1270   graph_->AddEdge(predicate, 0, if_node, index++);
   1271   for (auto& kv : cond_arg_nodes) {
   1272     bool inserted = false;
   1273     for (const Node* arg : kv.switches) {
   1274       const Edge* in_edge;
   1275       TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
   1276       if (in_edge->IsControlEdge()) {
   1277         graph_->AddControlEdge(in_edge->src(), if_node);
   1278       } else {
   1279         if (!inserted) {
   1280           graph_->AddEdge(in_edge->src(), in_edge->src_output(), if_node,
   1281                           index++);
   1282           inserted = true;
   1283         }
   1284       }
   1285     }
   1286   }
   1287   return Status::OK();
   1288 }
   1289 
   1290 Status FunctionalizeCond::AddOutputEdges(const std::vector<Node*>& outputs,
   1291                                          Node* if_node) {
   1292   VLOG(3) << "AddOutputEdges for " << if_node->name();
   1293   for (int i = 0; i < outputs.size(); ++i) {
   1294     Node* node = outputs[i];
   1295     std::vector<const Edge*> edges(node->out_edges().begin(),
   1296                                    node->out_edges().end());
   1297     for (const Edge* edge : edges) {
   1298       Node* dst = edge->dst();
   1299       int dst_input = edge->dst_input();
   1300 
   1301       if (edge->src_output() > 0) {
   1302         return errors::Unimplemented("Output of index (", edge->src_output(),
   1303                                      ") of merge node ", node->name());
   1304       }
   1305       graph_->RemoveEdge(edge);
   1306 
   1307       int src_output =
   1308           dst_input == Graph::kControlSlot ? Graph::kControlSlot : i;
   1309       graph_->AddEdge(if_node, src_output, dst, dst_input);
   1310     }
   1311   }
   1312   return Status::OK();
   1313 }
   1314 
   1315 StatusOr<Node*> FunctionalizeCond::ConvertToXlaIf(
   1316     const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster,
   1317     const std::vector<Node*>& merge_nodes) {
   1318   VLOG(1) << "ConvertToXlaIf for " << switch_cluster.ToString() << " -> "
   1319           << NodesToString(merge_nodes);
   1320 
   1321   // Extract bodies and builds a If operator.
   1322   TF_ASSIGN_OR_RETURN(
   1323       Node * if_node,
   1324       BuildAndAddXlaIfOp(cond_arg_nodes, switch_cluster, merge_nodes));
   1325   TF_RETURN_IF_ERROR(
   1326       AddInputEdges(cond_arg_nodes, switch_cluster.predicate, if_node));
   1327   TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node));
   1328 
   1329   return if_node;
   1330 }
   1331 
   1332 Status FunctionalizeCond::Functionalize(Graph* graph,
   1333                                         FunctionLibraryDefinition* library) {
   1334   VLOG(1) << "FunctionalizeCond::Functionalize";
   1335   FunctionalizeCond fc(graph, library, /*dump_graphs=*/VLOG_IS_ON(2));
   1336   return fc.FunctionalizeInternal();
   1337 }
   1338 
   1339 }  // namespace
   1340 
   1341 // Transformation that converts TensorFlow's graph control flow constructs into
   1342 // functional equivalents.
        //
        // Steps performed on `graph`, in order:
        //   1. Computes per-node control flow info with BuildControlFlowInfo().
        //   2. Groups the op nodes into Frames keyed by frame name, recording each
        //      frame's parent, its Enter nodes (as loop args) and its LoopCond node.
        //   3. Calls FunctionalizeLoop() on every non-root frame, innermost first.
        //   4. Calls FunctionalizeCond::Functionalize() on the resulting graph.
        //
        // Returns the first error encountered, or Status::OK() on success. `library`
        // is passed through to FunctionalizeLoop() and FunctionalizeCond.
   1343 Status FunctionalizeControlFlow(Graph* graph,
   1344                                 FunctionLibraryDefinition* library) {
   1345   VLOG(2) << "FunctionalizeControlFlow (initial): "
   1346           << dump_graph::DumpGraphToFile("functionalize_initial", *graph,
   1347                                          library);
   1348   // Note: BuildControlFlowInfo() requires that the graph's source node is
   1349   // connected to all source nodes in the graph. Many graphs violate this
   1350   // invariant.
   1351   std::vector<ControlFlowInfo> cf_info;
   1352   TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info));
   1353 
   1354   // Builds Frames, indexed by name.
   1355   std::unordered_map<string, Frame> frames;
   1356   for (Node* node : graph->op_nodes()) {
            // cf_info is indexed by node id.
   1357     const ControlFlowInfo& cf = cf_info[node->id()];
   1358 
   1359     VLOG(2) << "node: " << node->name() << " (" << node->id()
   1360             << ") frame_name: " << cf.frame_name
   1361             << " frame: " << (cf.frame ? cf.frame->name() : "---")
   1362             << " parent_frame: "
   1363             << (cf.parent_frame ? cf.parent_frame->name() : "---");
            // Every op node must have both a frame and a parent frame assigned.
   1364     TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr);
   1365 
            // operator[] creates a default Frame the first time a name is seen. The
            // second lookup may insert the parent entry; the `frame` reference stays
            // valid because std::unordered_map insertion does not invalidate
            // references to existing elements.
   1366     Frame& frame = frames[cf.frame_name];
   1367     Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name];
            // First node seen for this frame: record its parent and name and count it
            // as a child of the parent. Later nodes of the same frame must agree on
            // the parent.
            // NOTE(review): assumes a default-constructed Frame has parent == nullptr
            // and num_children == 0; Frame is defined outside this view -- confirm.
   1368     if (frame.parent == nullptr) {
   1369       frame.parent = parent;
   1370       frame.name = cf.frame_name;
   1371       ++parent->num_children;
   1372     } else if (frame.parent != parent) {
   1373       return errors::InvalidArgument("Mismatched parent frames for ",
   1374                                      cf.frame->id(), ": ", parent->name, " vs ",
   1375                                      frame.parent->name);
   1376     }
   1377 
   1378     if (IsEnter(node)) {
              // Each Enter node becomes a loop argument of its frame; the node's
              // "is_constant" attribute is stored as the arg's is_loop_invariant flag.
   1379       Arg arg;
   1380       arg.enter = node;
   1381       TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant",
   1382                                      &arg.is_loop_invariant));
   1383       frame.args.push_back(arg);
   1384     } else if (IsLoopCond(node)) {
              // A frame may contain at most one LoopCond node.
   1385       if (frame.loop_cond) {
   1386         return errors::InvalidArgument(
   1387             "Loop ", cf.frame_name,
   1388             " has more than one LoopCond node: ", node->name(), " and ",
   1389             frame.loop_cond->name());
   1390       }
   1391       frame.loop_cond = node;
   1392     }
            // Every op node, whatever its kind, is recorded as a member of its frame.
   1393     frame.nodes.insert(node);
   1394   }
   1395 
   1396   // Adds frames with no children (i.e., the innermost frames) to a worklist.
   1397   std::deque<Frame*> worklist;
   1398   for (auto& frame : frames) {
   1399     if (frame.second.num_children == 0) {
   1400       worklist.push_back(&frame.second);
   1401     }
   1402   }
   1403 
   1404   // Eliminate loops from innermost to outermost.
          // The worklist is processed in FIFO order; a parent is appended only after
          // all of its children have been functionalized.
   1405   while (!worklist.empty()) {
   1406     Frame* frame = worklist.front();
   1407     worklist.pop_front();
   1408     if (frame->parent == frame) {
   1409       // Skip the root frame.
   1410       continue;
   1411     }
   1412 
   1413     TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library));
   1414 
   1415     // If the parent has no remaining children, add it to the worklist.
   1416     --frame->parent->num_children;
   1417     if (frame->parent->num_children == 0) {
   1418       worklist.push_back(frame->parent);
   1419     }
   1420   }
   1421 
   1422   // FunctionalizeControlFlow is invoked for every function, so the loops'
   1423   // bodies and conditionals that were extracted into functions will be handled
   1424   // in successive invocations.
   1425   TF_RETURN_IF_ERROR(FunctionalizeCond::Functionalize(graph, library));
   1426 
   1427   VLOG(2) << "FunctionalizeControlFlow (final): "
   1428           << dump_graph::DumpGraphToFile("functionalize_final", *graph,
   1429                                          library);
   1430   return Status::OK();
   1431 }
   1432 
   1433 }  // namespace tensorflow
   1434