/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/list_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/node_matchers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

using ::tensorflow::testing::FindNodeByName;

namespace tensorflow {
namespace {

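// Test-only ops with no registered XLA kernel; nodes of these types can never
// be compiled and so must stay outside every cluster.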
REGISTER_OP("UncompilableNullary").Output("o: float");
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");

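// Returns a map from node name to cluster name for every node in `graph` that
// was assigned to a cluster (i.e. carries the kXlaClusterAttr attribute).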
std::unordered_map<string, string> GetClusters(const Graph& graph) {
  std::unordered_map<string, string> ids;
  for (Node* node : graph.nodes()) {
    string cluster;
    if (GetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster).ok()) {
      CHECK(!cluster.empty());
      ids[node->name()] = cluster;
    }
  }

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "Clusters:";
    for (const auto& p : ids) {
      VLOG(2) << " " << p.first << " -> " << p.second;
    }
  }
  return ids;
}

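// Returns a map from cluster name to the sorted names of the nodes in that
// cluster.  If `cluster_names` is non-null, it receives the sorted list of
// cluster names.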
absl::flat_hash_map<string, std::vector<string>> GetClusterSets(
    const Graph& g, std::vector<string>* cluster_names = nullptr) {
  CHECK(cluster_names == nullptr || cluster_names->empty());
  absl::flat_hash_map<string, std::vector<string>> cluster_sets;
  for (const auto& p : GetClusters(g)) {
    cluster_sets[p.second].push_back(p.first);
  }
  for (auto& p : cluster_sets) {
    if (cluster_names != nullptr) {
      cluster_names->push_back(p.first);
    }
    std::sort(p.second.begin(), p.second.end());
  }
  if (cluster_names != nullptr) {
    std::sort(cluster_names->begin(), cluster_names->end());
  }
  return cluster_sets;
}

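// A chain of compilable ops interrupted by uncompilable ones should be split
// into one cluster per maximal compilable run.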
TEST(XlaCompilationTest, Chains) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
    Node* d =
        ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D"));
    Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
    ops::UnaryOp("Relu", e, builder.opts().WithName("F"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_EQ(4, clusters.size());
  EXPECT_EQ(clusters["B"], clusters["C"]);
  EXPECT_EQ(clusters["E"], clusters["F"]);
  EXPECT_NE(clusters["B"], clusters["E"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
}

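// Clustering A with C would create a cycle through the uncompilable node B,
// so nothing can be clustered here.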
TEST(XlaCompilationTest, UncompilableCycles) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b =
        ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}

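// The same diamond as above, but with every op compilable: all three nodes
// can share one cluster.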
TEST(XlaCompilationTest, CompilableCycles) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(3, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["A"], clusters["C"]);
}

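// XLA has no kernels for DT_STRING, so string-typed ops are never clustered.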
TEST(XlaCompilationTest, StringUnsupported) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp(
        "Const", builder.opts()
                     .WithName("A")
                     .WithAttr("dtype", DT_STRING)
                     .WithAttr("value", Tensor(DT_STRING, TensorShape())));
    Node* b = ops::UnaryOp("EncodeBase64", a, builder.opts().WithName("B"));
    ops::BinaryOp("StringSplit", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_TRUE(clusters.empty());
}

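// DT_HALF is a supported type, so these ops should be clustered.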
TEST(XlaCompilationTest, HalfSupported) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Tensor t(DT_HALF, TensorShape());
    t.scalar<Eigen::half>()() = static_cast<Eigen::half>(0.0f);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_HALF)
                                         .WithAttr("value", t));
    Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_FALSE(clusters.empty());
}

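// Calls to compilable functions are clustered; calls to uncompilable or
// _noinline functions are not.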
TEST(XlaCompilationTest, FunctionCalls) {
  FunctionDef compilable = FunctionDefHelper::Define(
      "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
      {{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}});
  FunctionDef uncompilable =
      FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"},
                                {}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}});
  FunctionDef noinline = compilable;
  noinline.mutable_signature()->set_name("NoInlineFn");
  AddAttr("_noinline", static_cast<bool>(true), noinline.mutable_attr());

  FunctionDefLibrary flib;
  *flib.add_function() = compilable;
  *flib.add_function() = uncompilable;
  *flib.add_function() = noinline;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

  std::unique_ptr<Graph> graph(new Graph(&flib_def));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    Node* b = ops::BinaryOp("CompilableFn", a, a, builder.opts().WithName("B"));
    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
    ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D"));
    ops::BinaryOp("NoInlineFn", c, c, builder.opts().WithName("E"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_FALSE(clusters["B"].empty());
  EXPECT_EQ(clusters["B"], clusters["C"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
  EXPECT_TRUE(clusters.find("E") == clusters.cend());
}

// Metadata-only operators such as Shape/Rank/Size may not be the root of a
// cluster. This is partially to work around b/26800664, and partially because
// we should probably prefer to compile metadata operators with their producers
// wherever possible, rather than their consumers.
TEST(XlaCompilationTest, MetadataOpsDontStartClusters) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    // While all of the following ops are notionally compilable, none is
    // permitted to start a cluster. So nothing should be compiled.
    Node* b = ops::UnaryOp("Shape", a, builder.opts().WithName("B"));
    Node* c = ops::UnaryOp("Rank", b, builder.opts().WithName("C"));
    Node* d = ops::UnaryOp("Size", c, builder.opts().WithName("D"));
    ops::UnaryOp("Shape", d, builder.opts().WithName("E"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_EQ(0, clusters.size());  // Nothing should be compiled.
}

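// Wraps `nodes` into a unary-cwise gradient FunctionDef, defaulting each
// node's "T" attribute to float.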
static Status GradForUnaryCwise(FunctionDef* g,
                                std::vector<FunctionDefHelper::Node> nodes) {
  for (auto& n : nodes) {
    if (n.attr.empty()) {
      n.attr = {{"T", DT_FLOAT}};
    }
  }
  *g = FunctionDefHelper::Define(
      // Arg defs
      {"x: float", "dy: float"},
      // Ret val defs
      {"dx: float"},
      // Attr defs
      {},
      // Nodes
      nodes);
  return Status::OK();
}

// A gradient containing only supported operators.
Status SupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Tanh", {"x"}},
      {{"y2"}, "Square", {"y"}, {}, {"dy"}},
      FunctionDefHelper::Const("one", 1.0f),
      {{"a"}, "Sub", {"one", "y2"}},
      {{"dx"}, "Mul", {"dy", "a"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Supported", SupportedGrad);

// A gradient containing an unsupported operator.
Status UnsupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Tanh", {"x"}},
      {{"y2"}, "UncompilableUnary", {"y"}, {}, {"dy"}},
      FunctionDefHelper::Const("one", 1.0f),
      {{"a"}, "Sub", {"one", "y2"}},
      {{"dx"}, "Mul", {"dy", "a"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Unsupported", UnsupportedGrad);

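// SymbolicGradient nodes are clustered only when the function they
// differentiate is itself compilable.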
TEST(XlaCompilationTest, SymbolicGradients) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));

    // Builds a symbolic gradient for Supported.
    NodeBuilder b_builder("B", "SymbolicGradient",
                          builder.opts().op_registry());
    NameAttrList b_name_attr;
    b_name_attr.set_name("Supported");
    b_builder.Attr("f", b_name_attr);
    b_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
    b_builder.Attr("Tout", {DT_FLOAT});
    b_builder.Input({a, a});
    Node* b = builder.opts().FinalizeBuilder(&b_builder);

    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));

    // Builds a symbolic gradient for Unsupported.
    NodeBuilder d_builder("D", "SymbolicGradient",
                          builder.opts().op_registry());
    NameAttrList d_name_attr;
    d_name_attr.set_name("Unsupported");
    d_builder.Attr("f", d_name_attr);
    d_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
    d_builder.Attr("Tout", {DT_FLOAT});
    d_builder.Input({c, c});
    builder.opts().FinalizeBuilder(&d_builder);

    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_FALSE(clusters["B"].empty());
  EXPECT_EQ(clusters["B"], clusters["C"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
}

TEST(XlaCompilationTest, Loops) {
  // Regression test for b/32350199, where the autoclustering code introduced a
  // deadlock in a graph containing a while loop.
  Scope root = Scope::NewRootScope().ExitOnError();
  auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
  auto b = ops::Placeholder(root.WithOpName("B"), DT_FLOAT);
  auto c = ops::Add(root.WithOpName("C"), a, b);
  auto enter = ops::internal::Enter(root, c, "aframe");
  auto next_iter = ops::NextIteration(root, enter);
  auto exit = ops::internal::Exit(root, next_iter);
  auto d = ops::Add(root.WithOpName("D"), c, exit);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  // Nothing should be compiled. In particular, 'd' and 'c' must not be
  // compiled.
  EXPECT_EQ(0, clusters.size());
}

TEST(XlaCompilationTest, CyclesWithAllDifferentScopesGlobalJitOverridden) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp(
        "Relu", a,
        builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp(
        "MatMul", a, b,
        builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  FunctionDefLibrary flib;
  FunctionLibraryDefinition flib_def(graph->op_registry(), flib);
  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  // The computation is: C = MatMul(A, relu(A))
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
  // In this case the GlobalJitLevel takes precedence: the scope annotations
  // are ignored and all three nodes are clustered together.
  EXPECT_EQ(3, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["A"], clusters["C"]);
}

TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp(
        "Relu", a,
        builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp(
        "MatMul", a, b,
        builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false));
  auto clusters = GetClusters(*graph);

  // The computation is: C = MatMul(A, relu(A))
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
  // In this case, we cannot fuse anything, and there are no clusters.
  EXPECT_EQ(0, clusters.size());
}

TEST(XlaCompilationTest, CyclesWithSplittingScopes) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaCompileAttr, true)
                                         .WithAttr(kXlaScopeAttr, "Scope1"));
    Node* b = ops::UnaryOp("Relu", a,
                           builder.opts()
                               .WithName("B")
                               .WithAttr(kXlaCompileAttr, true)
                               .WithAttr(kXlaScopeAttr, "Scope1"));
    Node* c = ops::BinaryOp("MatMul", a, b,
                            builder.opts()
                                .WithName("C")
                                .WithAttr(kXlaCompileAttr, true)
                                .WithAttr(kXlaScopeAttr, "Scope2"));
    ops::BinaryOp("Add", b, c,
                  builder.opts()
                      .WithName("D")
                      .WithAttr(kXlaCompileAttr, true)
                      .WithAttr(kXlaScopeAttr, "Scope2"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false));
  auto clusters = GetClusters(*graph);

  // The computation is: D = relu(A) + (A @ relu(A))
  // where A and relu(A) are in Scope1, and the @ and + ops are in Scope2.
  // In this case, we can fuse A with relu(A), and we can fuse the second half
  // of the operations; there are two clusters.
  EXPECT_EQ(4, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_NE(clusters["A"], clusters["C"]);
  EXPECT_EQ(clusters["C"], clusters["D"]);
}

TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaCompileAttr, true)
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp("Relu", a,
                           builder.opts()
                               .WithName("B")
                               .WithAttr(kXlaCompileAttr, true)
                               .WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false));
  auto clusters = GetClusters(*graph);

  // The computation is: C = MatMul(A, relu(A))
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C has no scope.
  // A cannot be clustered with B because their scopes differ, but B and C can
  // still be clustered together.
  EXPECT_EQ(3, clusters.size());
  EXPECT_NE(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["B"], clusters["C"]);
}

namespace {
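// Builds a VarHandleOp feeding a ReadVariableOp and returns the read; if
// `var_handle_op` is non-null, it receives the variable handle node.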
Node* MakeRead(const Scope& scope, const string& id,
               Node** var_handle_op = nullptr) {
  Output var_handle =
      ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
  Output read =
      ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT);
  if (var_handle_op) {
    *var_handle_op = var_handle.node();
  }
  return read.node();
}

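// Builds a VarHandleOp and assigns a constant to it; returns the
// AssignVariableOp node.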
Node* MakeWrite(const Scope& scope, const string& id) {
  Output var_handle =
      ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
  Output value_to_write =
      ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f);
  ops::AssignVariableOp assign_op(scope.WithOpName("Assignment" + id),
                                  var_handle, value_to_write);
  return assign_op.operation.node();
}

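// Returns a constant node that neither reads nor writes a resource.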
Node* MakeNeutral(const Scope& scope, const string& id) {
  return ops::Const(scope.WithOpName("Const" + id), 42.0f).node();
}
}  // namespace

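// A resource read that is ordered before a write may be placed in the same
// cluster as the write.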
TEST(XlaCompilationTest, ResourcesClusteringAllowed) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Node* read = MakeRead(root, "R");
  Node* write = MakeWrite(root, "W");

  root.graph()->AddControlEdge(read, write);

  FixupSourceAndSinkEdges(root.graph());
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph);
  ASSERT_EQ(cluster_sets.size(), 1);
  std::vector<string> expected_clustered_nodes = {"AssignmentW", "ReadR",
                                                  "ValueToAssignW"};
  ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes);
}

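// A write ordered before a read must not be clustered: within a cluster, XLA
// performs resource reads at the start and writes at the end, which would
// invert the required ordering.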
TEST(XlaCompilationTest, ResourcesClusteringDisallowed) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Node* read = MakeRead(root, "R");
  Node* write = MakeWrite(root, "W");

  root.graph()->AddControlEdge(write, read);

  FixupSourceAndSinkEdges(root.graph());
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph);
  ASSERT_EQ(cluster_sets.size(), 0);
}

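// In a longer chain of resource ops, only a subchain whose read still
// precedes its write (here ReadR0 before AssignmentW1) can form a cluster.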
TEST(XlaCompilationTest, ChainOfOps) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Node* write_0 = MakeWrite(root, "W0");
  Node* neutral_0 = MakeNeutral(root, "N0");
  Node* read_0 = MakeRead(root, "R0");
  Node* write_1 = MakeWrite(root, "W1");
  Node* neutral_1 = MakeNeutral(root, "N1");
  Node* read_1 = MakeRead(root, "R1");

  root.graph()->AddControlEdge(write_0, neutral_0);
  root.graph()->AddControlEdge(neutral_0, read_0);
  root.graph()->AddControlEdge(read_0, write_1);
  root.graph()->AddControlEdge(write_1, neutral_1);
  root.graph()->AddControlEdge(neutral_1, read_1);

  FixupSourceAndSinkEdges(root.graph());
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::vector<string> cluster_names;
  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph, &cluster_names);

  ASSERT_EQ(cluster_sets.size(), 1);

  std::vector<string> expected_clustered_nodes_a = {
      "AssignmentW1", "ConstN1", "ReadR0", "ValueToAssignW1"};
  ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a);
}

TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Scope root = Scope::NewRootScope().ExitOnError();
  {
    auto BuildNoopNode = [](absl::string_view name, Graph* graph) {
      NodeDefBuilder builder(name, "NoOp");
      NodeDef def;
      TF_CHECK_OK(builder.Finalize(&def));

      Status status;
      Node* node = graph->AddNode(def, &status);
      TF_CHECK_OK(status);
      return node;
    };

    Node* a = BuildNoopNode("a", graph.get());
    Node* b = BuildNoopNode("b", graph.get());
    Node* c = BuildNoopNode("c", graph.get());
    graph->AddControlEdge(a, b);
    graph->AddControlEdge(b, c);
    graph->AddControlEdge(c, a);
  }

  TF_EXPECT_OK(root.ToGraph(graph.get()));

  Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph);
  EXPECT_FALSE(status.ok());
  EXPECT_TRUE(absl::StrContains(status.ToString(),
                                "Edge from c to a would create a cycle.\n"
                                "+-> a\n"
                                "|   b\n"
                                "+-- c\n"));
}

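// _Retval nodes must not be clustered; with R excluded, no profitable cluster
// remains.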
TEST(XlaCompilationTest, Retval) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    ops::UnaryOp("_Retval", b,
                 builder.opts()
                     .WithName("R")
                     .WithAttr("T", DT_FLOAT)
                     .WithAttr("index", 0));

    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}

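// Identity ops alone do not make a cluster worthwhile.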
TEST(XlaCompilationTest, DontCountIdentityOps) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Scope root = Scope::NewRootScope().ExitOnError();
  {
    auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
    auto b = ops::Identity(root.WithOpName("B"), a);
    auto c = ops::Identity(root.WithOpName("C"), b);
    auto r = ops::_Retval(root.WithOpName("R"), c, 0);
  }
  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, ConstOp) {
  // valid data type
  {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    Scope root = Scope::NewRootScope().ExitOnError();
    auto c = ops::Const(root.WithOpName("const"), 0.5f);
    c.node()->AddAttr(kXlaCompileAttr, true);
    TF_ASSERT_OK(root.ToGraph(graph.get()));
    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
    EXPECT_EQ(1, GetClusters(*graph).size());
  }

  // invalid data type
  {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    Scope root = Scope::NewRootScope().ExitOnError();
    auto c = ops::Const(root.WithOpName("const"), string("string"));
    c.node()->AddAttr(kXlaCompileAttr, true);
    TF_ASSERT_OK(root.ToGraph(graph.get()));
    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
    EXPECT_TRUE(GetClusters(*graph).empty());
  }
}

TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output variable = ops::Variable(root.WithOpName("variable"),
                                  PartialTensorShape{}, DT_FLOAT);
  Output read = ops::Identity(root.WithOpName("read"), variable);
  Output neg = ops::Negate(root.WithOpName("negate"), read);
  Output add = ops::Add(root.WithOpName("add"), neg, neg);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  ASSERT_FALSE(clusters.empty());
  string cluster_name = clusters.begin()->second;

  std::unordered_map<string, string> expected_clusters(
      {{"negate", cluster_name}, {"add", cluster_name}});
  EXPECT_EQ(clusters, expected_clusters);
}

TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output variable = ops::Variable(root.WithOpName("variable"),
                                  PartialTensorShape{}, DT_FLOAT);
  Output read = ops::Identity(root.WithOpName("read"), variable);
  Output neg = ops::Negate(root.WithOpName("negate"), read);
  Output identity = ops::Identity(root.WithOpName("identity"), neg);
  Output add = ops::Add(root.WithOpName("add"), identity, neg);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  ASSERT_FALSE(clusters.empty());
  string cluster_name = clusters.begin()->second;

  std::unordered_map<string, string> expected_clusters(
      {{"negate", cluster_name},
       {"identity", cluster_name},
       {"add", cluster_name}});
  EXPECT_EQ(clusters, expected_clusters);
}

TEST(XlaCompilationTest, ClusterControlTrigger) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_BOOL, "tensor_a",
                             "sender", 0, "receiver");
  Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_BOOL, "tensor_b",
                             "sender", 0, "receiver");
  Output const_a = ops::Const(root.WithOpName("const_a"), 42);

  ops::ControlTrigger ctrl_trigger_a(root.WithOpName("ctrl_trigger_a"));
  ops::ControlTrigger ctrl_trigger_b(root.WithOpName("ctrl_trigger_b"));
  root.graph()->AddControlEdge(recv_a.node(), ctrl_trigger_a.operation.node());
  root.graph()->AddControlEdge(recv_b.node(), ctrl_trigger_a.operation.node());
  root.graph()->AddControlEdge(ctrl_trigger_b.operation.node(), const_a.node());

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  // TODO(b/118970344): ctrl_trigger_a has inputs with mismatching deadness so
  // it won't be clustered.  ctrl_trigger_b is okay to cluster but we don't
  // cluster it because of b/118970344.
  EXPECT_TRUE(clusters.empty());
}

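// A node producing a randomly generated shape (consumed by Reshape) must not
// be clustered.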
TEST(XlaCompilationTest, RandomShape) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape_shape = ops::Const(root.WithOpName("shape_shape"), {2}, {1});
  Output shape =
      ops::RandomUniformInt(root.WithOpName("shape"), shape_shape,
                            ops::Const(root.WithOpName("minval"), 1),
                            ops::Const(root.WithOpName("maxval"), 20));
  Output reshape_input =
      ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshape =
      ops::Reshape(root.WithOpName("reshape"), reshape_input, shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["shape"], "");
}

TEST(XlaCompilationTest, RandomShapeWithFunc) {
  Scope root = Scope::DisabledShapeInferenceScope().ExitOnError();

  FunctionDefLibrary flib_def;
  FunctionDef func = FunctionDefHelper::Create(
      /*function_name=*/"Stateful_func", /*in_def=*/{},
      /*out_def=*/{"out: int32"},
      /*attr_def*/
      {}, /*node_def=*/
      {FunctionDefHelper::Const("shape_shape", 2),
       FunctionDefHelper::Const("minval", 1),
       FunctionDefHelper::Const("maxval", 20),
       {{"shape"},
        "RandomUniformInt",
        {"shape_shape:output:0", "minval:output:0", "maxval:output:0"},
        {{"Tout", DataType::DT_INT32}, {"T", DataType::DT_INT32}}}},
      /*ret_def=*/{{"out", "shape:output:0"}});

  func.mutable_signature()->set_is_stateful(true);
  *flib_def.add_function() = std::move(func);
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
  NodeDef call_node;
  call_node.set_name("fn_call");
  call_node.set_op("Stateful_func");
  Status status;
  Node* call = root.graph()->AddNode(call_node, &status);
  TF_ASSERT_OK(status);

  Output shape = Output(call, 0);
  Output reshape_input =
      ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshape =
      ops::Reshape(root.WithOpName("reshape"), reshape_input, shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));
  auto fld = absl::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(),
                                                          flib_def);
  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, fld.get()));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["fn_call"], "");
}

TEST(XlaCompilationTest, RandomShapeOnXlaDevice) {
  absl::string_view xla_gpu_device =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";

  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape_shape =
      ops::Const(root.WithOpName("test/shape_shape"), {2}, {1});
  Output shape =
      ops::RandomUniformInt(root.WithOpName("test/shape_rng"), shape_shape,
                            ops::Const(root.WithOpName("test/minval"), 1),
                            ops::Const(root.WithOpName("test/maxval"), 20));
  Output reshape_input =
      ops::Placeholder(root.WithOpName("test/reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshape =
      ops::Reshape(root.WithOpName("test/reshape"), reshape_input, shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_gpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/shape_rng"], "");
  EXPECT_EQ(clusters["test/reshape"], "");
}

TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) {
  absl::string_view xla_gpu_device =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::TensorArray tensor_array(root.WithOpName("test/tensor_array"), 1,
                                DT_INT32);
  Output zero = ops::Const(root.WithOpName("test/zero"), 0);
  ops::TensorArrayWrite tensor_array_write(
      root.WithOpName("test/write"), tensor_array.handle, zero,
      ops::Const(root.WithOpName("test/forty_two"), 42.0f), tensor_array.flow);
  Output tensor_array_read =
      ops::TensorArrayRead(root.WithOpName("test/read"), tensor_array.handle,
                           zero, tensor_array_write.flow_out, DT_INT32);
  Output reshape =
      ops::Reshape(root.WithOpName("test/reshape"),
                   ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT),
                   tensor_array_read);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_gpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/read"], "");
  EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]);
}

TEST(XlaCompilationTest, DontClusterMergingNodes) {
  // MatMulCombined below takes data from nodes on GPU0 and GPU1 and is placed
  // on GPU1. However, it should not be clustered with the previous node on
  // GPU1, because that would serialize production of its inputs, which should
  // happen in parallel.
  //
  // This graph is:
  // Const0 -> Tanh0; (Tanh0, Tanh0) -> MatMul0
  // Const1 -> Tanh1; (Tanh1, Tanh1) -> MatMul1
  // (MatMul0, MatMul1) -> MatMulCombined
  //
  // Device0: [Const0, Tanh0, MatMul0]
  // Device1: [Const1, Tanh1, MatMul1, MatMulCombined]
  //
  // Cluster0: [Const0, Tanh0, MatMul0]
  // Cluster1: [Const1, Tanh1, MatMul1]
  // Cluster2: [MatMulCombined]
  Scope root = Scope::NewRootScope().ExitOnError();
  absl::string_view xla_gpu_dev0 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  absl::string_view xla_gpu_dev1 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:1";
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"),
                       ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}));
  Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"),
                       ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}));
  Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a);
  Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b);

  Output combined =
      ops::MatMul(root.WithOpName("MatMulCombined_dev1"), matmul0, matmul1);
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
      n->set_assigned_device_name(string(xla_gpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
      n->set_assigned_device_name(string(xla_gpu_dev1));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  // Each of the MatMuls should be in a separate cluster.
  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
  EXPECT_NE(clusters["MatMulCombined_dev1"], clusters["MatMul0_dev0"]);
  EXPECT_NE(clusters["MatMulCombined_dev1"], clusters["MatMul1_dev1"]);
  EXPECT_EQ(clusters["A_dev0"], clusters["MatMul0_dev0"]);
  EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]);
}

// TODO(b/117085735): This form of clustering should be prevented.
TEST(XlaCompilationTest, NOT_DontClusterSpreadingNodes) {
  // MatMulSource below creates data for nodes on GPU0 and GPU1 and is placed
  // on GPU0. However, it should not be clustered with the next node on GPU0,
  // because that will prevent the node on GPU1 from beginning its work as
  // soon as the data has been produced.
  //
  // This graph is:
  // (Const0, Const0) -> MatMulSource
  // MatMulSource -> (MatMul0, MatMul1)
  //
  // Device0: [Const0, MatMulSource, MatMul0]
  // Device1: [MatMul1]
  //
  // Cluster0: [Const0, MatMulSource]
  // Cluster1: [MatMul0]
  // Cluster2: [MatMul1]
  Scope root = Scope::NewRootScope().ExitOnError();
  absl::string_view xla_gpu_dev0 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  absl::string_view xla_gpu_dev1 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:1";
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Output a = ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2});
  Output matmul_source =
      ops::MatMul(root.WithOpName("MatMulSource_dev0"), a, a);

  Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), matmul_source,
                               matmul_source);
  Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), matmul_source,
                               matmul_source);

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  for (Node* n : graph->nodes()) {
    if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
      n->set_assigned_device_name(string(xla_gpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
      n->set_assigned_device_name(string(xla_gpu_dev1));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["A_dev0"], clusters["MatMulSource_dev0"]);
  EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
  EXPECT_NE(clusters["MatMulSource_dev0"], clusters["MatMul1_dev1"]);

  // Improved heuristics should probably prevent this clustering.
  EXPECT_EQ(clusters["MatMulSource_dev0"], clusters["MatMul0_dev0"]);
}

TEST(XlaCompilationTest, ClusterStatefulRandomOpOnXlaDevice) {
  absl::string_view xla_cpu_device =
      "/job:worker/replica:0/task:0/device:XLA_CPU:0";

  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200});
  Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT);
  Output b = ops::RandomUniform(root.WithOpName("test/b"), shape, DT_FLOAT);
  Output c = ops::Add(root.WithOpName("test/c"), a, b);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_cpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/a"], "");
  EXPECT_NE(clusters["test/b"], "");
  EXPECT_NE(clusters["test/c"], "");
}

TEST(XlaCompilationTest, DontAutoClusterStatefulRandomOp) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200});
  Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT);
  Output b = ops::RandomUniform(root.WithOpName("test/b"), shape, DT_FLOAT);
  Output c = ops::Add(root.WithOpName("test/c"), a, b);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/a"], "");
  EXPECT_EQ(clusters["test/b"], "");
}

TEST(XlaCompilationTest, ClusterDummyOpsOnXlaDevice) {
  absl::string_view xla_cpu_device =
      "/job:worker/replica:0/task:0/device:XLA_CPU:0";

  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
  Output check =
      ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check");
  Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b);
  Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b});

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_cpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/check"], "");
  EXPECT_NE(clusters["test/greaterequal"], "");
  EXPECT_NE(clusters["test/assert"], "");
}

TEST(XlaCompilationTest, DontAutoClusterDummyOps) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
  Output check =
      ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check");
  Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b);
  Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b});

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/assert"], "");
  EXPECT_EQ(clusters["test/check"], "");
}

TEST(XlaCompilationTest, DontAutoClusterOpsProducingVariant) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64);

  Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32);
  Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32);

  Output tensor_list_reserve = ops::TensorListReserve(
      root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/tensor_list_reserve"], "");
}

TEST(XlaCompilationTest, DontAutoClusterOpsConsumingVariant) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output dummy_input =
      ops::Placeholder(root.WithOpName("test/dummy_input"), DT_INT64);
  Output variant_input =
      ops::Placeholder(root.WithOpName("test/variant_input"), DT_VARIANT);

  // Create one more node so that we don't avoid creating a cluster solely
  // because it would be trivial.
  Output dummy_cast =
      ops::Cast(root.WithOpName("test/dummy_cast"), dummy_input, DT_INT32);

  Output tensor_list_element_shape = ops::TensorListElementShape(
      root.WithOpName("test/tensor_list_element_shape"), variant_input,
      DT_INT32);

  root.graph()->AddControlEdge(dummy_cast.node(),
                               tensor_list_element_shape.node());

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/tensor_list_element_shape"], "");
}

TEST(XlaCompilationTest, ClusterOpsProducingVariantIfOnXlaDevice) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64);

  Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32);
  Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32);

  Output tensor_list_reserve = ops::TensorListReserve(
      root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0";
  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(xla_cpu_device);
    }
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/tensor_list_reserve"], "");
}

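// Fully qualified device names used by the placement-sensitive tests below.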
const char* kCPU0 = "/job:worker/replica:0/task:0/device:CPU:0";
const char* kGPU0 = "/job:worker/replica:0/task:0/device:GPU:0";
const char* kXLA_GPU0 = "/job:worker/replica:0/task:0/device:XLA_GPU:0";
const char* kGPU1 = "/job:worker/replica:0/task:0/device:GPU:1";

TEST(XlaCompilationTest, CreateCombinedCpuGpuClusters) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::MatMul(root.WithOpName("test/y"), a, b);
  Output z = ops::Add(root.WithOpName("test/z"), x, y);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_NE(clusters["test/x"], "");

  EXPECT_EQ(clusters["test/x"], clusters["test/y"]);
  EXPECT_EQ(clusters["test/y"], clusters["test/z"]);
}

TEST(XlaCompilationTest, DontCreateGpu0AndGpu1Clusters) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::Add(root.WithOpName("test/y"), x, x);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU1);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/x"], "");
  EXPECT_EQ(clusters["test/y"], "");
}

TEST(XlaCompilationTest, DontCreateCombinedCpuUnknownClusters) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::Add(root.WithOpName("test/y"), x, x);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kXLA_GPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/x"], "");
  EXPECT_EQ(clusters["test/y"], "");
}

TEST(XlaCompilationTest, ClusterResourceOpsWhenSafe) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Node* var_handle;
  Node* resource_read = MakeRead(root, "read", &var_handle);
  Output b = ops::Add(root.WithOpName("test/b"), Output(resource_read, 0), a);

  string resource_read_name = resource_read->name();
  string var_handle_name = var_handle->name();

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/b")->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), resource_read_name)
      ->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), var_handle_name)->set_assigned_device_name(kGPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_NE(clusters["test/b"], "");
  EXPECT_EQ(clusters["test/b"], clusters[resource_read_name]);
}

TEST(XlaCompilationTest, DontClusterResourceOpsWhenUnsafe) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Node* var_handle;
  Node* resource_read = MakeRead(root, "read", &var_handle);
  Output b = ops::Add(root.WithOpName("test/b"), Output(resource_read, 0), a);

  string resource_read_name = resource_read->name();
  string var_handle_name = var_handle->name();

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/b")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), resource_read_name)
      ->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), var_handle_name)->set_assigned_device_name(kCPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/b"], "");
  EXPECT_EQ(clusters[resource_read_name], "");
}

}  // namespace
}  // namespace tensorflow