Home | History | Annotate | Download | only in tf2xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
     17 
     18 #include "tensorflow/cc/framework/ops.h"
     19 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
     20 #include "tensorflow/cc/ops/function_ops.h"
     21 #include "tensorflow/cc/ops/resource_variable_ops.h"
     22 #include "tensorflow/cc/ops/standard_ops.h"
     23 #include "tensorflow/compiler/tf2xla/cc/ops/functional_ops.h"
     24 #include "tensorflow/compiler/tf2xla/test_util.h"
     25 #include "tensorflow/compiler/xla/status_macros.h"
     26 #include "tensorflow/core/common_runtime/function.h"
     27 #include "tensorflow/core/framework/function.h"
     28 #include "tensorflow/core/framework/node_def_util.h"
     29 #include "tensorflow/core/framework/op.h"
     30 #include "tensorflow/core/graph/graph_constructor.h"
     31 #include "tensorflow/core/graph/graph_def_builder.h"
     32 #include "tensorflow/core/lib/core/status_test_util.h"
     33 #include "tensorflow/core/platform/test.h"
     34 #include "tensorflow/core/util/equal_graph_def.h"
     35 
     36 namespace tensorflow {
     37 namespace {
     38 
     39 // Returns the names of the "then" and "else" functions for the XlaIf node in a
     40 // graph.
     41 Status FindIfThenAndElse(const GraphDef& graph, string* op_name,
     42                          NameAttrList* then_fn, NameAttrList* else_fn) {
     43   for (const NodeDef& node : graph.node()) {
     44     if (node.op() == "XlaIf") {
     45       *op_name = node.name();
     46       const NameAttrList* result;
     47       TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result));
     48       *then_fn = *result;
     49       TF_RETURN_IF_ERROR(GetNodeAttr(node, "else_branch", &result));
     50       *else_fn = *result;
     51       return Status::OK();
     52     }
     53   }
     54   return errors::NotFound("No XlaIf node found in graph");
     55 }
     56 
     57 // Graph:
     58 // x = array_ops.placeholder(dtypes.int32)
     59 // y = array_ops.placeholder(dtypes.int32)
     60 // z = control_flow_ops.cond(
     61 //     math_ops.less(y, x), lambda: math_ops.multiply(y, 17),
     62 //     lambda: math_ops.add(x, 23))
     63 TEST(FunctionalizeControlFlow, Conditional) {
     64   Graph graph(OpRegistry::Global());
     65   {
     66     Scope scope = Scope::NewRootScope().ExitOnError();
     67 
     68     auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
     69     auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
     70     auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
     71     auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), less, less);
     72 
     73     auto identity_t =
     74         ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_true);
     75     auto seventeen = ops::Const<int32>(
     76         scope.WithOpName("cond").WithControlDependencies(identity_t), 17);
     77     auto switch_2 = ops::Switch(scope.WithOpName("cond/Switch"), y, less);
     78     auto mul = ops::Multiply(scope.WithOpName("cond/Mul"), switch_2.output_true,
     79                              seventeen);
     80 
     81     auto identity_f =
     82         ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_false);
     83     auto twenty_three = ops::Const<int32>(
     84         scope.WithOpName("cond").WithControlDependencies(identity_f), 23);
     85     auto switch_3 = ops::Switch(scope.WithOpName("cond/Switch"), x, less);
     86     auto add = ops::Add(scope.WithOpName("cond/false/add"),
     87                         switch_3.output_false, twenty_three);
     88 
     89     auto merge = ops::Merge(scope.WithOpName("cond/Merge"),
     90                             std::initializer_list<Input>{add, mul});
     91 
     92     TF_EXPECT_OK(scope.ToGraph(&graph));
     93   }
     94 
     95   FunctionLibraryDefinition library(OpRegistry::Global(), {});
     96   TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
     97 
     98   GraphDef graph_def;
     99   graph.ToGraphDef(&graph_def);
    100   string op_name;
    101   NameAttrList then_fn;
    102   NameAttrList else_fn;
    103   TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn));
    104   InstantiationResultForTest else_result;
    105   TF_EXPECT_OK(
    106       InstantiateFunctionForTest(else_fn.name(), library, &else_result));
    107 
    108   // Outer graph
    109   {
    110     Scope scope = Scope::NewRootScope().ExitOnError();
    111     auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
    112     auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
    113     auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
    114     auto if_op = ops::XlaIf(scope.WithOpName(op_name), less,
    115                             std::initializer_list<Input>{less, y, x}, then_fn,
    116                             else_fn, {DT_INT32});
    117     GraphDef expected;
    118     TF_EXPECT_OK(scope.ToGraphDef(&expected));
    119     TF_EXPECT_GRAPH_EQ(expected, graph_def);
    120   }
    121 
    122   // then body.
    123   {
    124     Scope scope = Scope::NewRootScope().ExitOnError();
    125     auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0);
    126     auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
    127     auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
    128     auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0);
    129     auto cond = ops::Const(
    130         scope.WithOpName("cond").WithControlDependencies(identity), 17);
    131     auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond);
    132     auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), mul, 0);
    133 
    134     GraphDef expected;
    135     TF_EXPECT_OK(scope.ToGraphDef(&expected));
    136 
    137     InstantiationResultForTest result;
    138     TF_EXPECT_OK(InstantiateFunctionForTest(then_fn.name(), library, &result));
    139 
    140     EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
    141     EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types);
    142     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    143   }
    144 
    145   // else body.
    146   {
    147     Scope scope = Scope::NewRootScope().ExitOnError();
    148     auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0);
    149     auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
    150     auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
    151     auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0);
    152     auto cond_1 = ops::Const(
    153         scope.WithOpName("cond_1").WithControlDependencies(identity), 23);
    154     auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1);
    155     auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
    156 
    157     GraphDef expected;
    158     TF_EXPECT_OK(scope.ToGraphDef(&expected));
    159 
    160     InstantiationResultForTest result;
    161     TF_EXPECT_OK(InstantiateFunctionForTest(else_fn.name(), library, &result));
    162 
    163     EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
    164     EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types);
    165     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    166   }
    167 }
    168 
    169 // Returns the names of the "cond" and "body" functions for the While node
    170 // in a graph.
    171 Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond,
    172                             NameAttrList* body) {
    173   for (const NodeDef& node : graph.node()) {
    174     if (node.op() == "XlaWhile") {
    175       const NameAttrList* result;
    176       TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result));
    177       *cond = *result;
    178       TF_RETURN_IF_ERROR(GetNodeAttr(node, "body", &result));
    179       *body = *result;
    180       return Status::OK();
    181     }
    182   }
    183   return errors::NotFound("No XlaWhile node found in graph");
    184 }
    185 
    186 // Graph:
    187 // x = array_ops.placeholder(dtypes.int32)
    188 // y = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x])
    189 TEST(FunctionalizeControlFlow, OneLoopVar) {
    190   Graph graph(OpRegistry::Global());
    191   {
    192     Scope scope = Scope::NewRootScope().ExitOnError();
    193 
    194     auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
    195 
    196     auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
    197     auto enter =
    198         ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
    199     // Add an unused Enter node. These should be ignored.
    200     auto enter2 =
    201         ops::internal::Enter(scope.WithOpName("while/Enter2"), source, "aloop");
    202     auto merge = ops::Merge(scope.WithOpName("while/Merge"),
    203                             std::initializer_list<Input>{enter, dummy});
    204     auto ten = ops::Const<int32>(
    205         scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
    206         10);
    207     auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
    208     auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
    209     auto switch_ =
    210         ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
    211     auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
    212                                     switch_.output_false);
    213     auto identity =
    214         ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
    215     auto one = ops::Const<int32>(
    216         scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
    217     auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
    218     auto next_iteration =
    219         ops::NextIteration(scope.WithOpName("while/NextIteration"), add);
    220 
    221     auto sink = ops::Identity(scope.WithOpName("sink"), exit);
    222 
    223     // Remove the dummy node and add the loop backedge.
    224     scope.graph()->RemoveNode(dummy.node());
    225     scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);
    226 
    227     TF_EXPECT_OK(scope.ToGraph(&graph));
    228   }
    229 
    230   // Regression test: control edges from an Enter node to the graph sink should
    231   // be ignored.
    232   for (Node* n : graph.nodes()) {
    233     if (n->name() == "while/Enter") {
    234       graph.AddControlEdge(n, graph.sink_node());
    235     }
    236   }
    237 
    238   FunctionLibraryDefinition library(OpRegistry::Global(), {});
    239   TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
    240 
    241   GraphDef graph_def;
    242   graph.ToGraphDef(&graph_def);
    243 
    244   NameAttrList cond_fn, body_fn;
    245   TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
    246 
    247   // Outer graph
    248   {
    249     Scope scope = Scope::NewRootScope().ExitOnError();
    250     auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
    251     auto while_op =
    252         ops::XlaWhile(scope.WithOpName("while/LoopCond"),
    253                       std::initializer_list<Input>{source}, cond_fn, body_fn);
    254     auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
    255     GraphDef expected;
    256     TF_EXPECT_OK(scope.ToGraphDef(&expected));
    257     TF_EXPECT_GRAPH_EQ(expected, graph_def);
    258   }
    259 
    260   // Condition graph
    261   {
    262     Scope scope = Scope::NewRootScope().ExitOnError();
    263     auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
    264     auto ten = ops::Const<int32>(
    265         scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
    266     auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
    267     auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
    268 
    269     GraphDef expected;
    270     TF_EXPECT_OK(scope.ToGraphDef(&expected));
    271 
    272     InstantiationResultForTest result;
    273     TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result));
    274 
    275     EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
    276     EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
    277     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    278   }
    279 
    280   // Body graph.
    281   {
    282     Scope scope = Scope::NewRootScope().ExitOnError();
    283     auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
    284     auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
    285     auto one = ops::Const<int32>(
    286         scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
    287     auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
    288     auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
    289 
    290     GraphDef expected;
    291     TF_EXPECT_OK(scope.ToGraphDef(&expected));
    292 
    293     InstantiationResultForTest result;
    294     TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result));
    295 
    296     EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
    297     EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
    298     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    299   }
    300 }
    301 
    302 // Tests functionalizing OneLoopVar where the loop value is not used post the
    303 // loop.
    304 // Graph:
    305 // x = array_ops.placeholder(dtypes.int32)
    306 // control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x])
    307 TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) {
    308   Graph graph(OpRegistry::Global());
    309   {
    310     Scope scope = Scope::NewRootScope().ExitOnError();
    311 
    312     auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
    313 
    314     auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
    315     auto enter =
    316         ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
    317     auto merge = ops::Merge(scope.WithOpName("while/Merge"),
    318                             std::initializer_list<Input>{enter, dummy});
    319     auto ten = ops::Const<int32>(
    320         scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
    321         10);
    322     auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
    323     auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
    324     auto switch_ =
    325         ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
    326     auto identity =
    327         ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
    328     auto one = ops::Const<int32>(
    329         scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
    330     auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
    331     auto next_iteration =
    332         ops::NextIteration(scope.WithOpName("while/NextIteration"), add);
    333 
    334     // Remove the dummy node and add the loop backedge.
    335     scope.graph()->RemoveNode(dummy.node());
    336     scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);
    337 
    338     TF_EXPECT_OK(scope.ToGraph(&graph));
    339   }
    340 
    341   FunctionLibraryDefinition library(OpRegistry::Global(), {});
    342   TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
    343 
    344   GraphDef graph_def;
    345   graph.ToGraphDef(&graph_def);
    346 
    347   NameAttrList cond_fn, body_fn;
    348   TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
    349 
    350   // Outer graph
    351   {
    352     Scope scope = Scope::NewRootScope().ExitOnError();
    353     auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
    354     auto while_op =
    355         ops::XlaWhile(scope.WithOpName("while/LoopCond"),
    356                       std::initializer_list<Input>{source}, cond_fn, body_fn);
    357     GraphDef expected;
    358     TF_EXPECT_OK(scope.ToGraphDef(&expected));
    359     TF_EXPECT_GRAPH_EQ(expected, graph_def);
    360   }
    361 
    362   // Condition graph
    363   {
    364     Scope scope = Scope::NewRootScope().ExitOnError();
    365     auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
    366     auto ten = ops::Const<int32>(
    367         scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
    368     auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
    369     auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
    370 
    371     GraphDef expected;
    372     TF_EXPECT_OK(scope.ToGraphDef(&expected));
    373 
    374     InstantiationResultForTest result;
    375     TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result));
    376 
    377     EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
    378     EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
    379     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    380   }
    381 
    382   // Body graph.
    383   {
    384     Scope scope = Scope::NewRootScope().ExitOnError();
    385     auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
    386     auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
    387     auto one = ops::Const<int32>(
    388         scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
    389     auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
    390     auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
    391 
    392     GraphDef expected;
    393     TF_EXPECT_OK(scope.ToGraphDef(&expected));
    394 
    395     InstantiationResultForTest result;
    396     TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result));
    397 
    398     EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
    399     EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
    400     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    401   }
    402 }
    403 
    404 // Graph:
    405 // x = array_ops.placeholder(dtypes.int32)
    406 // y = array_ops.placeholder(dtypes.int32)
    407 // cond = lambda (i, j): i + 3 < 10
    408 // body = lambda (i, j): (i < 10, j * 2)
    409 // z = control_flow_ops.while_loop(cond, body, [x, y])
    410 TEST(FunctionalizeControlFlow, TwoLoopVars) {
    411   Graph graph(OpRegistry::Global());
    412   {
    413     Scope scope = Scope::NewRootScope().ExitOnError();
    414 
    415     auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
    416 
    417     auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
    418     auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
    419     auto enter_x =
    420         ops::internal::Enter(scope.WithOpName("while/Enter/x"), x, "aloop");
    421     auto enter_y =
    422         ops::internal::Enter(scope.WithOpName("while/Enter/y"), y, "aloop");
    423     auto merge_x = ops::Merge(scope.WithOpName("while/Merge/x"),
    424                               std::initializer_list<Input>{enter_x, dummy});
    425     auto merge_y = ops::Merge(scope.WithOpName("while/Merge/y"),
    426                               std::initializer_list<Input>{enter_y, dummy});
    427 
    428     // Loop condition
    429     auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
    430                                        .WithControlDependencies(merge_x.output),
    431                                    3);
    432     auto cond_add =
    433         ops::Add(scope.WithOpName("while/cond/Add"), merge_x.output, three);
    434     auto ten = ops::Const<int32>(scope.WithOpName("while/cond/ten")
    435                                      .WithControlDependencies(merge_x.output),
    436                                  10);
    437     auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
    438     auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
    439 
    440     auto switch_x = ops::Switch(scope.WithOpName("while/Switch/x"),
    441                                 merge_x.output, loop_cond);
    442     auto switch_y = ops::Switch(scope.WithOpName("while/Switch/y"),
    443                                 merge_y.output, loop_cond);
    444 
    445     auto exit_x = ops::internal::Exit(scope.WithOpName("while/Exit/x"),
    446                                       switch_x.output_false);
    447     auto exit_y = ops::internal::Exit(scope.WithOpName("while/Exit/y"),
    448                                       switch_y.output_false);
    449 
    450     auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"),
    451                                     switch_x.output_true);
    452     auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"),
    453                                     switch_y.output_true);
    454 
    455     auto one = ops::Const<int32>(
    456         scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
    457         1);
    458     auto two = ops::Const<int32>(
    459         scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
    460         2);
    461 
    462     auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one);
    463     auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two);
    464     auto next_iteration_x =
    465         ops::NextIteration(scope.WithOpName("while/NextIteration/x"), add);
    466     auto next_iteration_y =
    467         ops::NextIteration(scope.WithOpName("while/NextIteration/y"), mul);
    468 
    469     auto sink_x = ops::Identity(scope.WithOpName("sink_x"), exit_x);
    470     auto sink_y = ops::Identity(scope.WithOpName("sink_y"), exit_y);
    471 
    472     // Remove the dummy node and add the loop backedges.
    473     scope.graph()->RemoveNode(dummy.node());
    474     scope.graph()->AddEdge(next_iteration_x.node(), 0, merge_x.output.node(),
    475                            1);
    476     scope.graph()->AddEdge(next_iteration_y.node(), 0, merge_y.output.node(),
    477                            1);
    478 
    479     TF_EXPECT_OK(scope.ToGraph(&graph));
    480   }
    481 
    482   FunctionLibraryDefinition library(OpRegistry::Global(), {});
    483   TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
    484 
    485   GraphDef graph_def;
    486   graph.ToGraphDef(&graph_def);
    487 
    488   NameAttrList cond_fn, body_fn;
    489   TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
    490 
    491   // Outer graph.
    492   {
    493     Scope scope = Scope::NewRootScope().ExitOnError();
    494     auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
    495     auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
    496     auto while_op =
    497         ops::XlaWhile(scope.WithOpName("while/LoopCond"),
    498                       std::initializer_list<Input>{x, y}, cond_fn, body_fn);
    499     auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]);
    500     auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]);
    501     GraphDef expected;
    502     TF_EXPECT_OK(scope.ToGraphDef(&expected));
    503     TF_EXPECT_GRAPH_EQ(expected, graph_def);
    504   }
    505 
    506   // Condition graph.
    507   {
    508     Scope scope = Scope::NewRootScope().ExitOnError();
    509     auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
    510     auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
    511     auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
    512                                        .WithControlDependencies(arg0.output),
    513                                    3);
    514     auto cond_add =
    515         ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three);
    516     auto ten = ops::Const<int32>(
    517         scope.WithOpName("while/cond/ten").WithControlDependencies(arg0.output),
    518         10);
    519     auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
    520     auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
    521 
    522     GraphDef expected;
    523     TF_EXPECT_OK(scope.ToGraphDef(&expected));
    524 
    525     InstantiationResultForTest result;
    526     TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result));
    527 
    528     EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types);
    529     EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
    530     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    531   }
    532 
    533   // Body graph.
    534   {
    535     Scope scope = Scope::NewRootScope().ExitOnError();
    536     auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
    537     auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
    538 
    539     auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"), arg0);
    540     auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"), arg1);
    541 
    542     auto one = ops::Const<int32>(
    543         scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
    544         1);
    545     auto two = ops::Const<int32>(
    546         scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
    547         2);
    548 
    549     auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one);
    550     auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two);
    551     auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
    552     auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), mul, 1);
    553 
    554     GraphDef expected;
    555     TF_EXPECT_OK(scope.ToGraphDef(&expected));
    556 
    557     InstantiationResultForTest result;
    558     TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result));
    559 
    560     EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types);
    561     EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types);
    562     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    563   }
    564 }
    565 
    566 // Example with nesting, loop-invariant arguments, and resource variables.
    567 //
    568 // accum = resource_variable_ops.ResourceVariable(1)
    569 // x = array_ops.placeholder(dtypes.int32)
    570 // y = 3 + x
    571 //
    572 // def inner_body(j, k):
    573 //   add = state_ops.assign_add(accum, k * j + x)
    574 //   with ops.control_dependencies([add]):
    575 //     return [j + 1, k]
    576 //
    577 // def body(i):
    578 //   m = control_flow_ops.while_loop(lambda j, k: j < 5, inner_body,
    579 //                                   [1, y], name="inner")
    580 //   with ops.control_dependencies(m):
    581 //     return [i + 1]
    582 //
    583 // z = control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="outer")
    584 TEST(FunctionalizeControlFlow, Complex) {
    585   Graph graph(OpRegistry::Global());
    586   {
    587     Scope scope = Scope::NewRootScope().ExitOnError();
    588 
    589     auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
    590 
    591     auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
    592     auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
    593     auto y = ops::Add(scope.WithOpName("y"), x, three);
    594 
    595     auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
    596                                 TensorShape({}));
    597 
    598     // Outer loop
    599     auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
    600     auto enter_i =
    601         ops::internal::Enter(scope.WithOpName("outer/Enter_i"), zero, "outer");
    602     auto merge_i = ops::Merge(scope.WithOpName("outer/Merge_i"),
    603                               std::initializer_list<Input>{enter_i, dummy});
    604     auto ten = ops::Const<int32>(scope.WithOpName("outer/Less/y")
    605                                      .WithControlDependencies(merge_i.output),
    606                                  10);
    607     auto less_i =
    608         ops::Less(scope.WithOpName("outer/Less_i"), merge_i.output, ten);
    609     auto outer_loop_cond =
    610         ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_i);
    611     auto switch_i = ops::Switch(scope.WithOpName("outer/Switch"),
    612                                 merge_i.output, outer_loop_cond);
    613     auto exit_i = ops::internal::Exit(scope.WithOpName("outer/Exit"),
    614                                       switch_i.output_false);
    615     auto identity_i =
    616         ops::Identity(scope.WithOpName("outer/Identity"), switch_i.output_true);
    617 
    618     auto enter_x_outer =
    619         ops::internal::Enter(scope.WithOpName("outer/Enter_x"), x, "outer",
    620                              ops::internal::Enter::Attrs().IsConstant(true));
    621     auto enter_k_outer =
    622         ops::internal::Enter(scope.WithOpName("outer/Enter_k"), y, "outer",
    623                              ops::internal::Enter::Attrs().IsConstant(true));
    624     auto enter_var_outer =
    625         ops::internal::Enter(scope.WithOpName("outer/Enter_var"), var, "outer",
    626                              ops::internal::Enter::Attrs().IsConstant(true));
    627 
    628     // Inner loop
    629     auto one_j = ops::Const<int32>(
    630         scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
    631     auto enter_j = ops::internal::Enter(scope.WithOpName("outer/inner/Enter_j"),
    632                                         one_j, "inner");
    633     auto enter_k =
    634         ops::internal::Enter(scope.WithOpName("outer/inner/Enter_k")
    635                                  .WithControlDependencies(identity_i),
    636                              enter_k_outer, "inner");
    637     auto enter_x = ops::internal::Enter(
    638         scope.WithOpName("outer/inner/Enter_x"), enter_x_outer, "inner",
    639         ops::internal::Enter::Attrs().IsConstant(true));
    640     auto enter_var = ops::internal::Enter(
    641         scope.WithOpName("outer/inner/Enter_var"), enter_var_outer, "inner",
    642         ops::internal::Enter::Attrs().IsConstant(true));
    643 
    644     auto merge_j = ops::Merge(scope.WithOpName("outer/inner/Merge_j"),
    645                               std::initializer_list<Input>{enter_j, dummy});
    646     auto merge_k = ops::Merge(scope.WithOpName("outer/inner/Merge_k"),
    647                               std::initializer_list<Input>{enter_k, dummy});
    648 
    649     auto five = ops::Const<int32>(scope.WithOpName("outer/inner/Five")
    650                                       .WithControlDependencies(merge_j.output),
    651                                   5);
    652     auto less_j =
    653         ops::Less(scope.WithOpName("outer/inner/Less_j"), merge_j.output, five);
    654     auto loop_cond = ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_j);
    655 
    656     auto switch_j = ops::Switch(scope.WithOpName("outer/inner/Switch_j"),
    657                                 merge_j.output, loop_cond);
    658     auto switch_k = ops::Switch(scope.WithOpName("outer/inner/Switch_k"),
    659                                 merge_k.output, loop_cond);
    660     auto exit_j = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_j"),
    661                                       switch_j.output_false);
    662     auto exit_k = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_k"),
    663                                       switch_k.output_false);
    664     auto identity_j = ops::Identity(scope.WithOpName("outer/inner/Identity_j"),
    665                                     switch_j.output_true);
    666     auto identity_k = ops::Identity(scope.WithOpName("outer/inner/Identity_k"),
    667                                     switch_k.output_true);
    668 
    669     // Variable update
    670     auto mul_jk =
    671         ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
    672     auto add_jkx =
    673         ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, enter_x);
    674     auto assign = ops::AssignAddVariableOp(
    675         scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx);
    676 
    677     auto one =
    678         ops::Const<int32>(scope.WithOpName("outer/inner/One")
    679                               .WithControlDependencies(
    680                                   gtl::ArraySlice<Operation>{assign.operation}),
    681                           1);
    682     auto add_j =
    683         ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
    684 
    685     auto next_iteration_j = ops::NextIteration(
    686         scope.WithOpName("outer/inner/NextIteration_j"), add_j);
    687     auto next_iteration_k = ops::NextIteration(
    688         scope.WithOpName("outer/inner/NextIteration_k"), identity_k);
    689 
    690     // Body and backedge for outer loop.
    691     auto one_outer = ops::Const<int32>(
    692         scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
    693     auto add_i =
    694         ops::Add(scope.WithOpName("outer/add")
    695                      .WithControlDependencies(gtl::ArraySlice<Operation>{
    696                          exit_j.output.op(), exit_k.output.op()}),
    697                  identity_i, one_outer);
    698     auto next_iteration_i =
    699         ops::NextIteration(scope.WithOpName("outer/NextIteration"), add_i);
    700 
    701     auto sink = ops::Identity(scope.WithOpName("sink"), exit_i);
    702 
    703     // Remove the dummy node and add the loop backedge.
    704     scope.graph()->RemoveNode(dummy.node());
    705     scope.graph()->AddEdge(next_iteration_i.node(), 0, merge_i.output.node(),
    706                            1);
    707     scope.graph()->AddEdge(next_iteration_j.node(), 0, merge_j.output.node(),
    708                            1);
    709     scope.graph()->AddEdge(next_iteration_k.node(), 0, merge_k.output.node(),
    710                            1);
    711 
    712     TF_EXPECT_OK(scope.ToGraph(&graph));
    713   }
    714 
    715   FunctionLibraryDefinition library(OpRegistry::Global(), {});
    716   TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
    717 
    718   GraphDef graph_def;
    719   graph.ToGraphDef(&graph_def);
    720 
    721   NameAttrList outer_cond_fn, outer_body_fn;
    722   TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn));
    723 
    724   // Outer graph.
    725   {
    726     Scope scope = Scope::NewRootScope().ExitOnError();
    727     auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
    728     auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
    729     auto y = ops::Add(scope.WithOpName("y"), x, three);
    730 
    731     auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
    732                                 TensorShape({}));
    733 
    734     auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
    735 
    736     auto while_op = ops::XlaWhile(scope.WithOpName("outer/LoopCond"),
    737                                   std::initializer_list<Input>{zero, y, x, var},
    738                                   outer_cond_fn, outer_body_fn);
    739     auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
    740     GraphDef expected;
    741     TF_EXPECT_OK(scope.ToGraphDef(&expected));
    742     TF_EXPECT_GRAPH_EQ(expected, graph_def);
    743   }
    744 
    745   // Outer condition graph.
    746   {
    747     Scope scope = Scope::NewRootScope().ExitOnError();
    748     auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
    749     auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
    750     auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
    751     auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);
    752 
    753     auto ten = ops::Const<int32>(
    754         scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output),
    755         10);
    756     auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten);
    757     auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
    758 
    759     GraphDef expected;
    760     TF_EXPECT_OK(scope.ToGraphDef(&expected));
    761 
    762     InstantiationResultForTest result;
    763     TF_EXPECT_OK(
    764         InstantiateFunctionForTest(outer_cond_fn.name(), library, &result));
    765 
    766     EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
    767               result.arg_types);
    768     EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
    769     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    770   }
    771 
    772   // Outer body graph.
    773   NameAttrList inner_cond_fn, inner_body_fn;
    774   {
    775     InstantiationResultForTest result;
    776     TF_EXPECT_OK(
    777         InstantiateFunctionForTest(outer_body_fn.name(), library, &result));
    778 
    779     // Find the inner condition and body names.
    780     TF_EXPECT_OK(
    781         FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn));
    782 
    783     Scope scope = Scope::NewRootScope().ExitOnError();
    784     auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
    785     auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
    786     auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
    787     auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);
    788 
    789     auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0);
    790     auto one_j = ops::Const<int32>(
    791         scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
    792     auto while_op =
    793         ops::XlaWhile(scope.WithOpName("outer/LoopCond_1"),
    794                       std::initializer_list<Input>{one_j, arg1, arg2, arg3},
    795                       inner_cond_fn, inner_body_fn);
    796 
    797     auto one_outer = ops::Const<int32>(
    798         scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
    799     auto add_i =
    800         ops::Add(scope.WithOpName("outer/add")
    801                      .WithControlDependencies(gtl::ArraySlice<Operation>{
    802                          while_op[0].op(), while_op[1].op()}),
    803                  identity_i, one_outer);
    804 
    805     auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_i, 0);
    806     auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), arg1, 1);
    807     auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2);
    808 
    809     GraphDef expected;
    810     TF_EXPECT_OK(scope.ToGraphDef(&expected));
    811 
    812     EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
    813               result.arg_types);
    814     EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types);
    815     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    816   }
    817 
    818   // Inner condition graph.
    819   {
    820     Scope scope = Scope::NewRootScope().ExitOnError();
    821     auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
    822     auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
    823     auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
    824     auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);
    825 
    826     auto five = ops::Const<int32>(
    827         scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), 5);
    828     auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five);
    829     auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less_j, 0);
    830 
    831     GraphDef expected;
    832     TF_EXPECT_OK(scope.ToGraphDef(&expected));
    833 
    834     InstantiationResultForTest result;
    835     TF_EXPECT_OK(
    836         InstantiateFunctionForTest(inner_cond_fn.name(), library, &result));
    837 
    838     EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
    839               result.arg_types);
    840     EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
    841     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    842   }
    843 
    844   // Inner body graph.
    845   {
    846     Scope scope = Scope::NewRootScope().ExitOnError();
    847     auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
    848     auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
    849     auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
    850     auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);
    851 
    852     auto identity_j =
    853         ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0);
    854     auto identity_k =
    855         ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1);
    856 
    857     auto mul_jk =
    858         ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
    859     auto add_jkx = ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2);
    860     auto assign = ops::AssignAddVariableOp(
    861         scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx);
    862 
    863     auto one =
    864         ops::Const<int32>(scope.WithOpName("outer/inner/One")
    865                               .WithControlDependencies(
    866                                   gtl::ArraySlice<Operation>{assign.operation}),
    867                           1);
    868     auto add_j =
    869         ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
    870 
    871     auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_j, 0);
    872     auto retval1 =
    873         ops::_Retval(scope.WithOpName("_retval1_RetVal"), identity_k, 1);
    874     auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2);
    875 
    876     GraphDef expected;
    877     TF_EXPECT_OK(scope.ToGraphDef(&expected));
    878 
    879     InstantiationResultForTest result;
    880     TF_EXPECT_OK(
    881         InstantiateFunctionForTest(inner_body_fn.name(), library, &result));
    882 
    883     EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
    884               result.arg_types);
    885     EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types);
    886     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    887   }
    888 }
    889 
    890 }  // namespace
    891 }  // namespace tensorflow
    892