Home | History | Annotate | Download | only in jit
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
     17 
     18 #include "tensorflow/cc/framework/ops.h"
     19 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
     20 #include "tensorflow/cc/ops/standard_ops.h"
     21 #include "tensorflow/compiler/jit/defs.h"
     22 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
     23 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     24 #include "tensorflow/core/framework/node_def_util.h"
     25 #include "tensorflow/core/framework/op.h"
     26 #include "tensorflow/core/graph/graph_constructor.h"
     27 #include "tensorflow/core/graph/graph_def_builder.h"
     28 #include "tensorflow/core/graph/graph_def_builder_util.h"
     29 #include "tensorflow/core/lib/core/status_test_util.h"
     30 #include "tensorflow/core/platform/test.h"
     31 
     32 namespace tensorflow {
     33 namespace {
     34 
     35 REGISTER_OP("UncompilableNullary").Output("o: float");
     36 REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");
     37 
     38 Status MarkForCompilation(std::unique_ptr<Graph>* graph,
     39                           FunctionLibraryDefinition* flib_def) {
     40   // Assign all nodes to the CPU device.
     41   static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
     42   for (Node* n : (*graph)->nodes()) {
     43     n->set_assigned_device_name(kCpuDevice);
     44   }
     45 
     46   GraphOptimizationPassOptions opt_options;
     47   opt_options.graph = graph;
     48   opt_options.flib_def = flib_def;
     49   MarkForCompilationPass pass;
     50   return pass.RunImpl(opt_options);
     51 }
     52 
     53 Status MarkForCompilation(std::unique_ptr<Graph>* graph) {
     54   FunctionDefLibrary flib;
     55   FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
     56   return MarkForCompilation(graph, &flib_def);
     57 }
     58 
     59 std::unordered_map<string, string> GetClusters(const Graph& graph) {
     60   std::unordered_map<string, string> ids;
     61   for (Node* node : graph.nodes()) {
     62     string cluster;
     63     if (GetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster).ok()) {
     64       CHECK(!cluster.empty());
     65       ids[node->name()] = cluster;
     66     }
     67   }
     68   return ids;
     69 }
     70 
     71 TEST(XlaCompilationTest, Chains) {
     72   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
     73   GraphDef graphdef;
     74   {
     75     GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
     76     Node* a =
     77         ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
     78     Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
     79     Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
     80     Node* d =
     81         ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D"));
     82     Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
     83     ops::UnaryOp("Relu", e, builder.opts().WithName("F"));
     84     TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
     85   }
     86 
     87   TF_ASSERT_OK(MarkForCompilation(&graph));
     88   auto clusters = GetClusters(*graph);
     89   EXPECT_EQ(4, clusters.size());
     90   EXPECT_EQ(clusters["B"], clusters["C"]);
     91   EXPECT_EQ(clusters["E"], clusters["F"]);
     92   EXPECT_NE(clusters["B"], clusters["E"]);
     93   EXPECT_TRUE(clusters.find("A") == clusters.cend());
     94   EXPECT_TRUE(clusters.find("D") == clusters.cend());
     95 }
     96 
     97 TEST(XlaCompilationTest, UncompilableCycles) {
     98   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
     99   GraphDef graphdef;
    100   {
    101     GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    102     Node* a = ops::SourceOp("Const", builder.opts()
    103                                          .WithName("A")
    104                                          .WithAttr("dtype", DT_FLOAT)
    105                                          .WithAttr("value", Tensor()));
    106     Node* b =
    107         ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B"));
    108     ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    109     TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
    110   }
    111 
    112   TF_ASSERT_OK(MarkForCompilation(&graph));
    113   auto clusters = GetClusters(*graph);
    114 
    115   EXPECT_TRUE(clusters.empty());
    116 }
    117 
    118 TEST(XlaCompilationTest, CompilableCycles) {
    119   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    120   GraphDef graphdef;
    121   {
    122     GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    123     Node* a = ops::SourceOp("Const", builder.opts()
    124                                          .WithName("A")
    125                                          .WithAttr("dtype", DT_FLOAT)
    126                                          .WithAttr("value", Tensor()));
    127     Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    128     ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    129     TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
    130   }
    131 
    132   TF_ASSERT_OK(MarkForCompilation(&graph));
    133   auto clusters = GetClusters(*graph);
    134 
    135   EXPECT_EQ(3, clusters.size());
    136   EXPECT_EQ(clusters["A"], clusters["B"]);
    137   EXPECT_EQ(clusters["A"], clusters["C"]);
    138 }
    139 
    140 TEST(XlaCompilationTest, UnsupportedTypes) {
    141   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    142   GraphDef graphdef;
    143   {
    144     GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    145     Node* a = ops::SourceOp(
    146         "Const", builder.opts()
    147                      .WithName("A")
    148                      .WithAttr("dtype", DT_COMPLEX128)
    149                      .WithAttr("value", Tensor(DT_COMPLEX128, TensorShape())));
    150     Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
    151     ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    152     TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
    153   }
    154 
    155   TF_ASSERT_OK(MarkForCompilation(&graph));
    156   auto clusters = GetClusters(*graph);
    157   EXPECT_TRUE(clusters.empty());
    158 }
    159 
    160 TEST(XlaCompilationTest, ConcatWithConstArg) {
    161   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    162   GraphDef graphdef;
    163   {
    164     Tensor t(DT_INT32, TensorShape());
    165     t.scalar<int32>()() = 0;
    166     GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    167     Node* dim = ops::SourceOp("Const", builder.opts()
    168                                            .WithName("Dim")
    169                                            .WithAttr("dtype", DT_INT32)
    170                                            .WithAttr("value", t));
    171     Node* a = ops::SourceOp("Const", builder.opts()
    172                                          .WithName("A")
    173                                          .WithAttr("dtype", DT_FLOAT)
    174                                          .WithAttr("value", t));
    175 
    176     NodeBuilder concat_builder("Concat", "Concat",
    177                                builder.opts().op_registry());
    178     concat_builder.Input(dim).Input({a, a}).Attr("N", 2);
    179     builder.opts().FinalizeBuilder(&concat_builder);
    180 
    181     TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
    182   }
    183 
    184   TF_ASSERT_OK(MarkForCompilation(&graph));
    185   auto clusters = GetClusters(*graph);
    186   EXPECT_EQ(3, clusters.size());  // Everything should be compiled.
    187 }
    188 
    189 TEST(XlaCompilationTest, FunctionCalls) {
    190   FunctionDef compilable = FunctionDefHelper::Define(
    191       "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
    192       {{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}});
    193   FunctionDef uncompilable =
    194       FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"},
    195                                 {}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}});
    196   FunctionDef noinline = compilable;
    197   noinline.mutable_signature()->set_name("NoInlineFn");
    198   AddAttr("_noinline", bool(true), noinline.mutable_attr());
    199 
    200   FunctionDefLibrary flib;
    201   *flib.add_function() = compilable;
    202   *flib.add_function() = uncompilable;
    203   *flib.add_function() = noinline;
    204   FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
    205 
    206   std::unique_ptr<Graph> graph(new Graph(&flib_def));
    207   GraphDef graphdef;
    208   {
    209     GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
    210     Node* a =
    211         ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    212     Node* b = ops::BinaryOp("CompilableFn", a, a, builder.opts().WithName("B"));
    213     Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
    214     ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D"));
    215     ops::BinaryOp("NoInlineFn", c, c, builder.opts().WithName("E"));
    216     TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
    217   }
    218 
    219   TF_ASSERT_OK(MarkForCompilation(&graph, &flib_def));
    220   auto clusters = GetClusters(*graph);
    221 
    222   EXPECT_EQ(2, clusters.size());
    223   EXPECT_FALSE(clusters["B"].empty());
    224   EXPECT_EQ(clusters["B"], clusters["C"]);
    225   EXPECT_TRUE(clusters.find("A") == clusters.cend());
    226   EXPECT_TRUE(clusters.find("D") == clusters.cend());
    227   EXPECT_TRUE(clusters.find("E") == clusters.cend());
    228 }
    229 
    230 // Metadata-only operators such as Shape/Rank/Size may not be the root of a
    231 // cluster. This is partially to work around b/26800664, and partially because
    232 // we should probably prefer to compile metadata operators with their producers
    233 // wherever possible, rather than their consumers.
    234 TEST(XlaCompilationTest, MetadataOpsDontStartClusters) {
    235   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    236   GraphDef graphdef;
    237   {
    238     GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    239     Node* a =
    240         ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    241     // While all of the following ops are notionally compilable, none is
    242     // permitted
    243     // to start a cluster. So nothing should be compiled.
    244     Node* b = ops::UnaryOp("Shape", a, builder.opts().WithName("B"));
    245     Node* c = ops::UnaryOp("Rank", b, builder.opts().WithName("C"));
    246     Node* d = ops::UnaryOp("Size", c, builder.opts().WithName("D"));
    247     ops::UnaryOp("Shape", d, builder.opts().WithName("E"));
    248     TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
    249   }
    250   TF_ASSERT_OK(MarkForCompilation(&graph));
    251   auto clusters = GetClusters(*graph);
    252   EXPECT_EQ(0, clusters.size());  // Nothing should be compiled.
    253 }
    254 
    255 static Status GradForUnaryCwise(FunctionDef* g,
    256                                 std::vector<FunctionDefHelper::Node> nodes) {
    257   for (auto& n : nodes) {
    258     if (n.attr.empty()) {
    259       n.attr = {{"T", DT_FLOAT}};
    260     }
    261   }
    262   *g = FunctionDefHelper::Define(
    263       // Arg defs
    264       {"x: float", "dy: float"},
    265       // Ret val defs
    266       {"dx: float"},
    267       // Attr defs
    268       {},
    269       // Nodes
    270       nodes);
    271   return Status::OK();
    272 }
    273 
    274 // A gradient containing only supported operators
    275 Status SupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
    276   // clang-format off
    277   return GradForUnaryCwise(g, {
    278       {{"y"}, "Tanh", {"x"}},
    279       {{"y2"}, "Square", {"y"}, {}, {"dy"}},
    280       FunctionDefHelper::Const("one", 1.0f),
    281       {{"a"}, "Sub", {"one", "y2"}},
    282       {{"dx"}, "Mul", {"dy", "a"}},
    283   });
    284   // clang-format on
    285 }
    286 REGISTER_OP_GRADIENT("Supported", SupportedGrad);
    287 
    288 // A gradient containing an unsupported operator.
    289 Status UnsupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
    290   // clang-format off
    291   return GradForUnaryCwise(g, {
    292       {{"y"}, "Tanh", {"x"}},
    293       {{"y2"}, "UncompilableUnary", {"y"}, {}, {"dy"}},
    294       FunctionDefHelper::Const("one", 1.0f),
    295       {{"a"}, "Sub", {"one", "y2"}},
    296       {{"dx"}, "Mul", {"dy", "a"}},
    297   });
    298   // clang-format on
    299 }
    300 REGISTER_OP_GRADIENT("Unsupported", UnsupportedGrad);
    301 
    302 TEST(XlaCompilationTest, SymbolicGradients) {
    303   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    304   GraphDef graphdef;
    305   {
    306     GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    307     Node* a =
    308         ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    309 
    310     // Builds a Symbolic gradient for Supported
    311     NodeBuilder b_builder("B", "SymbolicGradient",
    312                           builder.opts().op_registry());
    313     NameAttrList b_name_attr;
    314     b_name_attr.set_name("Supported");
    315     b_builder.Attr("f", b_name_attr);
    316     b_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
    317     b_builder.Attr("Tout", {DT_FLOAT});
    318     b_builder.Input({a, a});
    319     Node* b = builder.opts().FinalizeBuilder(&b_builder);
    320 
    321     Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
    322 
    323     // Builds a Symbolic gradient for Unsupported
    324     NodeBuilder d_builder("D", "SymbolicGradient",
    325                           builder.opts().op_registry());
    326     NameAttrList d_name_attr;
    327     d_name_attr.set_name("Unsupported");
    328     d_builder.Attr("f", d_name_attr);
    329     d_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
    330     d_builder.Attr("Tout", {DT_FLOAT});
    331     d_builder.Input({c, c});
    332     builder.opts().FinalizeBuilder(&d_builder);
    333 
    334     TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
    335   }
    336 
    337   TF_ASSERT_OK(MarkForCompilation(&graph));
    338   auto clusters = GetClusters(*graph);
    339 
    340   EXPECT_EQ(2, clusters.size());
    341   EXPECT_FALSE(clusters["B"].empty());
    342   EXPECT_EQ(clusters["B"], clusters["C"]);
    343   EXPECT_TRUE(clusters.find("A") == clusters.cend());
    344   EXPECT_TRUE(clusters.find("D") == clusters.cend());
    345 }
    346 
    347 TEST(XlaCompilationTest, Loops) {
    348   // Regression test for b/32350199, where the autoclustering code introduced a
    349   // deadlock in a graph containing a while loop.
    350   Scope root = Scope::NewRootScope().ExitOnError();
    351   auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
    352   auto b = ops::Placeholder(root.WithOpName("B"), DT_FLOAT);
    353   auto c = ops::Add(root.WithOpName("C"), a, b);
    354   auto enter = ops::internal::Enter(root, c, "aframe");
    355   auto next_iter = ops::NextIteration(root, enter);
    356   auto exit = ops::internal::Exit(root, next_iter);
    357   auto d = ops::Add(root.WithOpName("D"), c, exit);
    358 
    359   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    360   TF_EXPECT_OK(root.ToGraph(graph.get()));
    361 
    362   TF_ASSERT_OK(MarkForCompilation(&graph));
    363   auto clusters = GetClusters(*graph);
    364 
    365   // Nothing should be compiled. In particular, 'd' and 'c' must not be
    366   // compiled.
    367   EXPECT_EQ(0, clusters.size());
    368 }
    369 
    370 TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
    371   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    372   GraphDef graphdef;
    373   {
    374     GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    375     Node* a = ops::SourceOp("Const", builder.opts()
    376                                          .WithName("A")
    377                                          .WithAttr("dtype", DT_FLOAT)
    378                                          .WithAttr("value", Tensor())
    379                                          .WithAttr(kXlaScopeAttr, "ScopeA"));
    380     Node* b = ops::UnaryOp(
    381         "Relu", a,
    382         builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
    383     ops::BinaryOp(
    384         "MatMul", a, b,
    385         builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC"));
    386     TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
    387   }
    388 
    389   TF_ASSERT_OK(MarkForCompilation(&graph));
    390   auto clusters = GetClusters(*graph);
    391 
    392   // The computation is: C = A + relu(A)
    393   // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
    394   // In this case, we cannot fuse anything, and there are no clusters.
    395   EXPECT_EQ(0, clusters.size());
    396 }
    397 
    398 TEST(XlaCompilationTest, CyclesWithSplittingScopes) {
    399   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    400   GraphDef graphdef;
    401   {
    402     GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    403     Node* a = ops::SourceOp("Const", builder.opts()
    404                                          .WithName("A")
    405                                          .WithAttr("dtype", DT_FLOAT)
    406                                          .WithAttr("value", Tensor())
    407                                          .WithAttr(kXlaScopeAttr, "Scope1"));
    408     Node* b = ops::UnaryOp(
    409         "Relu", a,
    410         builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "Scope1"));
    411     Node* c = ops::BinaryOp(
    412         "MatMul", a, b,
    413         builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "Scope2"));
    414     ops::BinaryOp(
    415         "Add", b, c,
    416         builder.opts().WithName("D").WithAttr(kXlaScopeAttr, "Scope2"));
    417     TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
    418   }
    419 
    420   TF_ASSERT_OK(MarkForCompilation(&graph));
    421   auto clusters = GetClusters(*graph);
    422 
    423   // The computation is: D = relu(A) + (A @ relu(A))
    424   // where A and relu(A) are in Scope1, and the @, + ops are in Scope2.
    425   // In this case, we can fuse the A and relu(A), and we can fuse the
    426   // second half of the operations; there are two clusters.
    427   EXPECT_EQ(4, clusters.size());
    428   EXPECT_EQ(clusters["A"], clusters["B"]);
    429   EXPECT_NE(clusters["A"], clusters["C"]);
    430   EXPECT_EQ(clusters["C"], clusters["D"]);
    431 }
    432 
    433 TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
    434   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    435   GraphDef graphdef;
    436   {
    437     GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    438     Node* a = ops::SourceOp("Const", builder.opts()
    439                                          .WithName("A")
    440                                          .WithAttr("dtype", DT_FLOAT)
    441                                          .WithAttr("value", Tensor())
    442                                          .WithAttr(kXlaScopeAttr, "ScopeA"));
    443     Node* b = ops::UnaryOp(
    444         "Relu", a,
    445         builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
    446     ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    447     TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
    448   }
    449 
    450   TF_ASSERT_OK(MarkForCompilation(&graph));
    451   auto clusters = GetClusters(*graph);
    452 
    453   // The computation is: C = A @ relu(A)
    454   // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
    455   // In this case, we cannot fuse anything.
    456   EXPECT_EQ(2, clusters.size());
    457   EXPECT_NE(clusters["A"], clusters["B"]);
    458   EXPECT_EQ(clusters["B"], clusters["C"]);
    459 }
    460 
    461 REGISTER_OP("ResourceInput").Input("a: resource").Output("o: float");
    462 REGISTER_OP("ResourceOutput").Input("a: float").Output("o: resource");
    463 
    464 namespace {
    465 
    466 class DummyOp : public XlaOpKernel {
    467   using XlaOpKernel::XlaOpKernel;
    468   void Compile(XlaOpKernelContext* ctx) override {}
    469 };
    470 
    471 REGISTER_XLA_OP(Name("ResourceInput"), DummyOp);
    472 REGISTER_XLA_OP(Name("ResourceOutput"), DummyOp);
    473 
    474 }  // namespace
    475 
    476 TEST(XlaCompilationTest, Resources) {
    477   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    478   GraphDef graphdef;
    479   {
    480     GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    481     Node* a =
    482         ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    483     Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    484     // We should not form clusters with resource ops by default.
    485     Node* c = ops::UnaryOp("ResourceOutput", b, builder.opts().WithName("C"));
    486     Node* d = ops::UnaryOp("ResourceInput", c, builder.opts().WithName("D"));
    487     ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
    488     TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
    489   }
    490   TF_ASSERT_OK(MarkForCompilation(&graph));
    491   auto clusters = GetClusters(*graph);
    492   EXPECT_EQ(0, clusters.size());  // Nothing should be compiled.
    493 }
    494 
    495 TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
    496   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    497   Scope root = Scope::NewRootScope().ExitOnError();
    498   {
    499     auto BuildNoopNode = [](StringPiece name, Graph* graph) {
    500       NodeDefBuilder builder(name, "NoOp");
    501       NodeDef def;
    502       TF_CHECK_OK(builder.Finalize(&def));
    503 
    504       Status status;
    505       Node* node = graph->AddNode(def, &status);
    506       TF_CHECK_OK(status);
    507       return node;
    508     };
    509 
    510     Node* a = BuildNoopNode("a", graph.get());
    511     Node* b = BuildNoopNode("b", graph.get());
    512     Node* c = BuildNoopNode("c", graph.get());
    513     graph->AddControlEdge(a, b);
    514     graph->AddControlEdge(b, c);
    515     graph->AddControlEdge(c, a);
    516   }
    517 
    518   TF_EXPECT_OK(root.ToGraph(graph.get()));
    519 
    520   Status status = MarkForCompilation(&graph);
    521   EXPECT_FALSE(status.ok());
    522   EXPECT_TRUE(StringPiece(status.ToString())
    523                   .contains("Edge from c to a would create a cycle.\n"
    524                             "+-> a\n"
    525                             "|   b\n"
    526                             "+-- c\n"));
    527 }
    528 
    529 TEST(XlaCompilationTest, Retval) {
    530   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    531   GraphDef graphdef;
    532   {
    533     GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    534     Node* a = ops::SourceOp("Const", builder.opts()
    535                                          .WithName("A")
    536                                          .WithAttr("dtype", DT_FLOAT)
    537                                          .WithAttr("value", Tensor()));
    538     Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    539     ops::UnaryOp("_Retval", b,
    540                  builder.opts()
    541                      .WithName("R")
    542                      .WithAttr("T", DT_FLOAT)
    543                      .WithAttr("index", 0));
    544 
    545     TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
    546   }
    547 
    548   TF_ASSERT_OK(MarkForCompilation(&graph));
    549   auto clusters = GetClusters(*graph);
    550 
    551   EXPECT_EQ(2, clusters.size());
    552   EXPECT_TRUE(clusters.find("R") == clusters.cend());
    553   EXPECT_EQ(clusters["A"], clusters["B"]);
    554 }
    555 
    556 }  // namespace
    557 }  // namespace tensorflow
    558