/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/list_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/node_matchers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

using ::tensorflow::testing::FindNodeByName;

namespace tensorflow {
namespace {

REGISTER_OP("UncompilableNullary").Output("o: float");
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");

// Returns a map from node name to the XLA cluster the node was assigned to;
// nodes without a cluster attribute are omitted.
std::unordered_map<string, string> GetClusters(const Graph& graph) {
  std::unordered_map<string, string> ids;
  for (Node* node : graph.nodes()) {
    string cluster;
    if (GetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster).ok()) {
      CHECK(!cluster.empty());
      ids[node->name()] = cluster;
    }
  }

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "Clusters:";
    for (const auto& p : ids) {
      VLOG(2) << " " << p.first << " -> " << p.second;
    }
  }
  return ids;
}

// Groups clustered nodes by cluster: returns a map from cluster name to the
// sorted names of the nodes in that cluster, and optionally the sorted list of
// cluster names.
absl::flat_hash_map<string, std::vector<string>> GetClusterSets(
    const Graph& g, std::vector<string>* cluster_names = nullptr) {
  CHECK(cluster_names == nullptr || cluster_names->empty());
  absl::flat_hash_map<string, std::vector<string>> cluster_sets;
  for (const auto& p : GetClusters(g)) {
    cluster_sets[p.second].push_back(p.first);
  }
  for (auto& p : cluster_sets) {
    if (cluster_names != nullptr) {
      cluster_names->push_back(p.first);
    }
    std::sort(p.second.begin(), p.second.end());
  }
  if (cluster_names != nullptr) {
    std::sort(cluster_names->begin(), cluster_names->end());
  }
  return cluster_sets;
}

TEST(XlaCompilationTest, Chains) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary",
builder.opts().WithName("A")); 95 Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); 96 Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C")); 97 Node* d = 98 ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D")); 99 Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E")); 100 ops::UnaryOp("Relu", e, builder.opts().WithName("F")); 101 TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); 102 } 103 104 TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); 105 auto clusters = GetClusters(*graph); 106 EXPECT_EQ(4, clusters.size()); 107 EXPECT_EQ(clusters["B"], clusters["C"]); 108 EXPECT_EQ(clusters["E"], clusters["F"]); 109 EXPECT_NE(clusters["B"], clusters["E"]); 110 EXPECT_TRUE(clusters.find("A") == clusters.cend()); 111 EXPECT_TRUE(clusters.find("D") == clusters.cend()); 112 } 113 114 TEST(XlaCompilationTest, UncompilableCycles) { 115 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); 116 GraphDef graphdef; 117 { 118 GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); 119 Node* a = ops::SourceOp("Const", builder.opts() 120 .WithName("A") 121 .WithAttr("dtype", DT_FLOAT) 122 .WithAttr("value", Tensor())); 123 Node* b = 124 ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B")); 125 ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); 126 TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); 127 } 128 129 TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); 130 auto clusters = GetClusters(*graph); 131 132 EXPECT_TRUE(clusters.empty()); 133 } 134 135 TEST(XlaCompilationTest, CompilableCycles) { 136 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); 137 GraphDef graphdef; 138 { 139 GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); 140 Node* a = ops::SourceOp("Const", builder.opts() 141 .WithName("A") 142 .WithAttr("dtype", DT_FLOAT) 143 .WithAttr("value", Tensor())); 144 Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); 145 ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); 146 TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); 147 } 148 149 TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); 150 auto clusters = GetClusters(*graph); 151 152 EXPECT_EQ(3, clusters.size()); 153 EXPECT_EQ(clusters["A"], clusters["B"]); 154 EXPECT_EQ(clusters["A"], clusters["C"]); 155 } 156 157 TEST(XlaCompilationTest, StringUnsupported) { 158 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); 159 GraphDef graphdef; 160 { 161 GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); 162 Node* a = ops::SourceOp( 163 "Const", builder.opts() 164 .WithName("A") 165 .WithAttr("dtype", DT_STRING) 166 .WithAttr("value", Tensor(DT_STRING, TensorShape()))); 167 Node* b = ops::UnaryOp("EncodeBase64", a, builder.opts().WithName("B")); 168 ops::BinaryOp("StringSplit", a, b, builder.opts().WithName("C")); 169 TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); 170 } 171 172 TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); 173 auto clusters = GetClusters(*graph); 174 EXPECT_TRUE(clusters.empty()); 175 } 176 177 TEST(XlaCompilationTest, HalfSupported) { 178 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); 179 GraphDef graphdef; 180 { 181 GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); 182 Tensor t(DT_HALF, TensorShape()); 183 t.scalar<Eigen::half>()() = static_cast<Eigen::half>(0.0f); 184 Node* a = ops::SourceOp("Const", builder.opts() 
                                         .WithName("A")
                                         .WithAttr("dtype", DT_HALF)
                                         .WithAttr("value", t));
    Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_FALSE(clusters.empty());
}

TEST(XlaCompilationTest, FunctionCalls) {
  FunctionDef compilable = FunctionDefHelper::Define(
      "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
      {{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}});
  FunctionDef uncompilable =
      FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"},
                                {}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}});
  FunctionDef noinline = compilable;
  noinline.mutable_signature()->set_name("NoInlineFn");
  AddAttr("_noinline", static_cast<bool>(true), noinline.mutable_attr());

  FunctionDefLibrary flib;
  *flib.add_function() = compilable;
  *flib.add_function() = uncompilable;
  *flib.add_function() = noinline;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

  std::unique_ptr<Graph> graph(new Graph(&flib_def));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    Node* b = ops::BinaryOp("CompilableFn", a, a, builder.opts().WithName("B"));
    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
    ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D"));
    ops::BinaryOp("NoInlineFn", c, c, builder.opts().WithName("E"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_FALSE(clusters["B"].empty());
  EXPECT_EQ(clusters["B"], clusters["C"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
  EXPECT_TRUE(clusters.find("E") == clusters.cend());
}

// Metadata-only operators such as Shape/Rank/Size may not be the root of a
// cluster. This is partially to work around b/26800664, and partially because
// we should probably prefer to compile metadata operators with their producers
// wherever possible, rather than their consumers.
TEST(XlaCompilationTest, MetadataOpsDontStartClusters) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    // While all of the following ops are notionally compilable, none is
    // permitted to start a cluster. So nothing should be compiled.
    Node* b = ops::UnaryOp("Shape", a, builder.opts().WithName("B"));
    Node* c = ops::UnaryOp("Rank", b, builder.opts().WithName("C"));
    Node* d = ops::UnaryOp("Size", c, builder.opts().WithName("D"));
    ops::UnaryOp("Shape", d, builder.opts().WithName("E"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_EQ(0, clusters.size());  // Nothing should be compiled.
}

static Status GradForUnaryCwise(FunctionDef* g,
                                std::vector<FunctionDefHelper::Node> nodes) {
  for (auto& n : nodes) {
    if (n.attr.empty()) {
      n.attr = {{"T", DT_FLOAT}};
    }
  }
  *g = FunctionDefHelper::Define(
      // Arg defs
      {"x: float", "dy: float"},
      // Ret val defs
      {"dx: float"},
      // Attr defs
      {},
      // Nodes
      nodes);
  return Status::OK();
}

// A gradient containing only supported operators
Status SupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Tanh", {"x"}},
      {{"y2"}, "Square", {"y"}, {}, {"dy"}},
      FunctionDefHelper::Const("one", 1.0f),
      {{"a"}, "Sub", {"one", "y2"}},
      {{"dx"}, "Mul", {"dy", "a"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Supported", SupportedGrad);

// A gradient containing an unsupported operator.
Status UnsupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Tanh", {"x"}},
      {{"y2"}, "UncompilableUnary", {"y"}, {}, {"dy"}},
      FunctionDefHelper::Const("one", 1.0f),
      {{"a"}, "Sub", {"one", "y2"}},
      {{"dx"}, "Mul", {"dy", "a"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Unsupported", UnsupportedGrad);

TEST(XlaCompilationTest, SymbolicGradients) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));

    // Builds a Symbolic gradient for Supported
    NodeBuilder b_builder("B", "SymbolicGradient",
                          builder.opts().op_registry());
    NameAttrList b_name_attr;
    b_name_attr.set_name("Supported");
    b_builder.Attr("f", b_name_attr);
    b_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
    b_builder.Attr("Tout", {DT_FLOAT});
    b_builder.Input({a, a});
    Node* b = builder.opts().FinalizeBuilder(&b_builder);

    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));

    // Builds a Symbolic gradient for Unsupported
    NodeBuilder d_builder("D", "SymbolicGradient",
                          builder.opts().op_registry());
    NameAttrList d_name_attr;
    d_name_attr.set_name("Unsupported");
    d_builder.Attr("f", d_name_attr);
    d_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
    d_builder.Attr("Tout", {DT_FLOAT});
    d_builder.Input({c, c});
    builder.opts().FinalizeBuilder(&d_builder);

    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_FALSE(clusters["B"].empty());
  EXPECT_EQ(clusters["B"], clusters["C"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
}

TEST(XlaCompilationTest,
     Loops) {
  // Regression test for b/32350199, where the autoclustering code introduced
  // a deadlock in a graph containing a while loop.
  Scope root = Scope::NewRootScope().ExitOnError();
  auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
  auto b = ops::Placeholder(root.WithOpName("B"), DT_FLOAT);
  auto c = ops::Add(root.WithOpName("C"), a, b);
  auto enter = ops::internal::Enter(root, c, "aframe");
  auto next_iter = ops::NextIteration(root, enter);
  auto exit = ops::internal::Exit(root, next_iter);
  auto d = ops::Add(root.WithOpName("D"), c, exit);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  // Nothing should be compiled. In particular, 'd' and 'c' must not be
  // compiled.
  EXPECT_EQ(0, clusters.size());
}

TEST(XlaCompilationTest, CyclesWithAllDifferentScopesGlobalJitOverridden) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp(
        "Relu", a,
        builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp(
        "MatMul", a, b,
        builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  FunctionDefLibrary flib;
  FunctionLibraryDefinition flib_def(graph->op_registry(), flib);
  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  // The computation is: C = A @ relu(A)
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
  // In this case, the GlobalJitLevel overrides the individual scopes, so A, B,
  // and C are clustered together despite their differing scopes.
  EXPECT_EQ(3, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["A"], clusters["C"]);
}

TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp(
        "Relu", a,
        builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp(
        "MatMul", a, b,
        builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false));
  auto clusters = GetClusters(*graph);

  // The computation is: C = A @ relu(A)
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
  // In this case, we cannot fuse anything, and there are no clusters.
  EXPECT_EQ(0, clusters.size());
}

TEST(XlaCompilationTest, CyclesWithSplittingScopes) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaCompileAttr, true)
                                         .WithAttr(kXlaScopeAttr, "Scope1"));
    Node* b = ops::UnaryOp("Relu", a,
                           builder.opts()
                               .WithName("B")
                               .WithAttr(kXlaCompileAttr, true)
                               .WithAttr(kXlaScopeAttr, "Scope1"));
    Node* c = ops::BinaryOp("MatMul", a, b,
                            builder.opts()
                                .WithName("C")
                                .WithAttr(kXlaCompileAttr, true)
                                .WithAttr(kXlaScopeAttr, "Scope2"));
    ops::BinaryOp("Add", b, c,
                  builder.opts()
                      .WithName("D")
                      .WithAttr(kXlaCompileAttr, true)
                      .WithAttr(kXlaScopeAttr, "Scope2"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false));
  auto clusters = GetClusters(*graph);

  // The computation is: D = relu(A) + (A @ relu(A))
  // where A and relu(A) are in Scope1, and the @, + ops are in Scope2.
  // In this case, we can fuse A and relu(A), and we can fuse the second half
  // of the operations; there are two clusters.
  EXPECT_EQ(4, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_NE(clusters["A"], clusters["C"]);
  EXPECT_EQ(clusters["C"], clusters["D"]);
}

TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaCompileAttr, true)
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp("Relu", a,
                           builder.opts()
                               .WithName("B")
                               .WithAttr(kXlaCompileAttr, true)
                               .WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false));
  auto clusters = GetClusters(*graph);

  // The computation is: C = A @ relu(A)
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C has no scope.
  // A and B cannot be fused because their scopes differ, but C (the "bridge",
  // which has no scope) can be clustered with B.
  EXPECT_EQ(3, clusters.size());
  EXPECT_NE(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["B"], clusters["C"]);
}

namespace {
Node* MakeRead(const Scope& scope, const string& id,
               Node** var_handle_op = nullptr) {
  Output var_handle = ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT,
                                       TensorShape({}));
  Output read =
      ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT);
  if (var_handle_op) {
    *var_handle_op = var_handle.node();
  }
  return read.node();
}

Node* MakeWrite(const Scope& scope, const string& id) {
  Output var_handle = ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT,
                                       TensorShape({}));
  Output value_to_write =
      ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f);
  ops::AssignVariableOp assign_op(scope.WithOpName("Assignment" + id),
                                  var_handle, value_to_write);
  return assign_op.operation.node();
}

Node* MakeNeutral(const Scope& scope, const string& id) {
  return ops::Const(scope.WithOpName("Const" + id), 42.0f).node();
}
}  // namespace

TEST(XlaCompilationTest, ResourcesClusteringAllowed) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Node* read = MakeRead(root, "R");
  Node* write = MakeWrite(root, "W");

  root.graph()->AddControlEdge(read, write);

  FixupSourceAndSinkEdges(root.graph());
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph);
  ASSERT_EQ(cluster_sets.size(), 1);
  std::vector<string> expected_clustered_nodes = {"AssignmentW", "ReadR",
                                                  "ValueToAssignW"};
  ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes);
}

TEST(XlaCompilationTest, ResourcesClusteringDisallowed) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Node* read = MakeRead(root, "R");
  Node* write = MakeWrite(root, "W");

  root.graph()->AddControlEdge(write, read);

  FixupSourceAndSinkEdges(root.graph());
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph);
  ASSERT_EQ(cluster_sets.size(), 0);
}

TEST(XlaCompilationTest, ChainOfOps) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Node* write_0 = MakeWrite(root, "W0");
  Node* neutral_0 = MakeNeutral(root, "N0");
  Node* read_0 = MakeRead(root, "R0");
  Node* write_1 = MakeWrite(root, "W1");
  Node* neutral_1 = MakeNeutral(root, "N1");
  Node* read_1 = MakeRead(root, "R1");

  root.graph()->AddControlEdge(write_0, neutral_0);
  root.graph()->AddControlEdge(neutral_0, read_0);
  root.graph()->AddControlEdge(read_0, write_1);
  root.graph()->AddControlEdge(write_1, neutral_1);
  root.graph()->AddControlEdge(neutral_1, read_1);

  FixupSourceAndSinkEdges(root.graph());
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::vector<string> cluster_names;
  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph, &cluster_names);

  ASSERT_EQ(cluster_sets.size(), 1);

  std::vector<string> expected_clustered_nodes_a = {
      "AssignmentW1", "ConstN1", "ReadR0", "ValueToAssignW1"};
  ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a);
}

TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Scope root = Scope::NewRootScope().ExitOnError();
  {
    auto BuildNoopNode = [](absl::string_view name, Graph* graph) {
      NodeDefBuilder builder(name, "NoOp");
      NodeDef def;
      TF_CHECK_OK(builder.Finalize(&def));

      Status status;
      Node* node = graph->AddNode(def, &status);
      TF_CHECK_OK(status);
      return node;
    };

    Node* a = BuildNoopNode("a", graph.get());
    Node* b = BuildNoopNode("b", graph.get());
    Node* c = BuildNoopNode("c", graph.get());
    graph->AddControlEdge(a, b);
    graph->AddControlEdge(b, c);
    graph->AddControlEdge(c, a);
  }

  TF_EXPECT_OK(root.ToGraph(graph.get()));

  Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph);
  EXPECT_FALSE(status.ok());
  EXPECT_TRUE(absl::StrContains(status.ToString(),
                                "Edge from c to a would create a cycle.\n"
                                "+-> a\n"
                                "| b\n"
                                "+-- c\n"));
}

TEST(XlaCompilationTest, Retval) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    ops::UnaryOp("_Retval", b,
                 builder.opts()
                     .WithName("R")
                     .WithAttr("T", DT_FLOAT)
                     .WithAttr("index", 0));

    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, DontCountIdentityOps) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Scope root = Scope::NewRootScope().ExitOnError();
  {
    auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
    auto b = ops::Identity(root.WithOpName("B"), a);
    auto c = ops::Identity(root.WithOpName("C"), b);
    auto r = ops::_Retval(root.WithOpName("R"), c, 0);
  }
  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, ConstOp) {
  // valid data type
  {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    Scope root = Scope::NewRootScope().ExitOnError();
    auto c = ops::Const(root.WithOpName("const"), 0.5f);
    c.node()->AddAttr(kXlaCompileAttr, true);
    TF_ASSERT_OK(root.ToGraph(graph.get()));
    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
    EXPECT_EQ(1, GetClusters(*graph).size());
  }

  // invalid data type
  {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    Scope root = Scope::NewRootScope().ExitOnError();
    auto c = ops::Const(root.WithOpName("const"), string("string"));
    c.node()->AddAttr(kXlaCompileAttr, true);
    TF_ASSERT_OK(root.ToGraph(graph.get()));
    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
    EXPECT_TRUE(GetClusters(*graph).empty());
  }
}

TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output variable = ops::Variable(root.WithOpName("variable"),
                                  PartialTensorShape{}, DT_FLOAT);
  Output read = ops::Identity(root.WithOpName("read"), variable);
  Output neg = ops::Negate(root.WithOpName("negate"), read);
  Output add = ops::Add(root.WithOpName("add"), neg, neg);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  ASSERT_FALSE(clusters.empty());
  string cluster_name = clusters.begin()->second;

  std::unordered_map<string, string> expected_clusters(
      {{"negate", cluster_name}, {"add", cluster_name}});
  EXPECT_EQ(clusters, expected_clusters);
}

TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output variable = ops::Variable(root.WithOpName("variable"),
                                  PartialTensorShape{}, DT_FLOAT);
  Output read = ops::Identity(root.WithOpName("read"), variable);
  Output neg = ops::Negate(root.WithOpName("negate"), read);
  Output identity = ops::Identity(root.WithOpName("identity"), neg);
  Output add = ops::Add(root.WithOpName("add"), identity, neg);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  ASSERT_FALSE(clusters.empty());
  string cluster_name = clusters.begin()->second;

  std::unordered_map<string, string> expected_clusters(
      {{"negate", cluster_name},
       {"identity", cluster_name},
       {"add", cluster_name}});
  EXPECT_EQ(clusters, expected_clusters);
}

TEST(XlaCompilationTest, ClusterControlTrigger) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_BOOL, "tensor_a",
                             "sender", 0, "receiver");
  Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_BOOL, "tensor_b",
                             "sender", 0, "receiver");
  Output const_a = ops::Const(root.WithOpName("const_a"), 42);

  ops::ControlTrigger ctrl_trigger_a(root.WithOpName("ctrl_trigger_a"));
  ops::ControlTrigger ctrl_trigger_b(root.WithOpName("ctrl_trigger_b"));
  root.graph()->AddControlEdge(recv_a.node(), ctrl_trigger_a.operation.node());
  root.graph()->AddControlEdge(recv_b.node(), ctrl_trigger_a.operation.node());
  root.graph()->AddControlEdge(ctrl_trigger_b.operation.node(), const_a.node());

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  // TODO(b/118970344): ctrl_trigger_a has inputs with mismatching deadness so
  // it won't be clustered. ctrl_trigger_b is okay to cluster but we don't
  // cluster it because of b/118970344.
  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, RandomShape) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape_shape = ops::Const(root.WithOpName("shape_shape"), {2}, {1});
  Output shape =
      ops::RandomUniformInt(root.WithOpName("shape"), shape_shape,
                            ops::Const(root.WithOpName("minval"), 1),
                            ops::Const(root.WithOpName("maxval"), 20));
  Output reshape_input =
      ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshape =
      ops::Reshape(root.WithOpName("reshape"), reshape_input, shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["shape"], "");
}

TEST(XlaCompilationTest, RandomShapeWithFunc) {
  Scope root = Scope::DisabledShapeInferenceScope().ExitOnError();

  FunctionDefLibrary flib_def;
  FunctionDef func = FunctionDefHelper::Create(
      /*function_name=*/"Stateful_func", /*in_def=*/{},
      /*out_def=*/{"out: int32"},
      /*attr_def*/
      {}, /*node_def=*/
      {FunctionDefHelper::Const("shape_shape", 2),
       FunctionDefHelper::Const("minval", 1),
       FunctionDefHelper::Const("maxval", 20),
       {{"shape"},
        "RandomUniformInt",
        {"shape_shape:output:0", "minval:output:0", "maxval:output:0"},
        {{"Tout", DataType::DT_INT32}, {"T", DataType::DT_INT32}}}},
      /*ret_def=*/{{"out", "shape:output:0"}});

  func.mutable_signature()->set_is_stateful(true);
  *flib_def.add_function() = std::move(func);
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
  NodeDef call_node;
  call_node.set_name("fn_call");
  call_node.set_op("Stateful_func");
  Status status;
  Node* call = root.graph()->AddNode(call_node, &status);
  TF_ASSERT_OK(status);

  Output shape = Output(call, 0);
  Output reshape_input =
      ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshape =
      ops::Reshape(root.WithOpName("reshape"), reshape_input, shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));
  auto fld = absl::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(),
                                                          flib_def);
  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, fld.get()));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["fn_call"], "");
}

TEST(XlaCompilationTest, RandomShapeOnXlaDevice) {
  absl::string_view xla_gpu_device =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";

  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape_shape =
      ops::Const(root.WithOpName("test/shape_shape"), {2}, {1});
  Output shape =
      ops::RandomUniformInt(root.WithOpName("test/shape_rng"), shape_shape,
                            ops::Const(root.WithOpName("test/minval"), 1),
                            ops::Const(root.WithOpName("test/maxval"), 20));
  Output reshape_input =
      ops::Placeholder(root.WithOpName("test/reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshape =
      ops::Reshape(root.WithOpName("test/reshape"), reshape_input, shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_gpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/shape_rng"], "");
  EXPECT_EQ(clusters["test/reshape"], "");
}

TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) {
  absl::string_view xla_gpu_device =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::TensorArray tensor_array(root.WithOpName("test/tensor_array"), 1,
                                DT_INT32);
  Output zero = ops::Const(root.WithOpName("test/zero"), 0);
  ops::TensorArrayWrite tensor_array_write(
      root.WithOpName("test/write"), tensor_array.handle, zero,
      ops::Const(root.WithOpName("test/forty_two"), 42.0f), tensor_array.flow);
  Output tensor_array_read =
      ops::TensorArrayRead(root.WithOpName("test/read"), tensor_array.handle,
                           zero, tensor_array_write.flow_out, DT_INT32);
  Output reshape =
      ops::Reshape(root.WithOpName("test/reshape"),
                   ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT),
                   tensor_array_read);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_gpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/read"], "");
  EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]);
}

TEST(XlaCompilationTest, DontClusterMergingNodes) {
  // MatMulCombined below takes data from nodes on GPU0 and GPU1 and is placed
  // on GPU1. However, it should not be clustered with the previous node on
  // GPU1, because that will serialize production of its inputs that should be
  // done in parallel.
  //
  // This graph is:
  // (Const0, Const0) -> MatMul0
  // (Const1, Const1) -> MatMul1
  // (MatMul0, MatMul1) -> MatMulCombined
  //
  // Device0: [Const0, Const0, MatMul0]
  // Device1: [Const1, Const1, MatMul1, MatMulCombined]
  //
  // Cluster0: [Const0, Const0, MatMul0]
  // Cluster1: [Const1, Const1, MatMul1]
  // Cluster2: [MatMulCombined]
  Scope root = Scope::NewRootScope().ExitOnError();
  absl::string_view xla_gpu_dev0 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  absl::string_view xla_gpu_dev1 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:1";
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"),
                       ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}));
  Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"),
                       ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}));
  Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a);
  Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b);

  Output combined =
      ops::MatMul(root.WithOpName("MatMulCombined_dev1"), matmul0, matmul1);
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
      n->set_assigned_device_name(string(xla_gpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
      n->set_assigned_device_name(string(xla_gpu_dev1));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  // Each of the MatMuls should be in a separate cluster.
  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
  EXPECT_NE(clusters["MatMulCombined_dev1"], clusters["MatMul0_dev0"]);
  EXPECT_NE(clusters["MatMulCombined_dev1"], clusters["MatMul1_dev1"]);
  EXPECT_EQ(clusters["A_dev0"], clusters["MatMul0_dev0"]);
  EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]);
}

// TODO(b/117085735): This form of clustering should be prevented.
TEST(XlaCompilationTest, NOT_DontClusterSpreadingNodes) {
  // MatMulSource below creates data for nodes on GPU0 and GPU1 and is placed
  // on GPU0. However, it should not be clustered with the next node on GPU0,
  // because that will prevent the node on GPU1 from beginning its work as
  // soon as the data has been produced.
  //
  // This graph is:
  // (Const0, Const0) -> MatMulSource
  // MatMulSource -> (MatMul0, MatMul1)
  //
  // Device0: [Const0, Const1, MatMulSource, MatMul0]
  // Device1: [MatMul1]
  //
  // Cluster0: [Const0, Const1, MatMulSource]
  // Cluster1: [MatMul0]
  // Cluster2: [MatMul1]
  Scope root = Scope::NewRootScope().ExitOnError();
  absl::string_view xla_gpu_dev0 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  absl::string_view xla_gpu_dev1 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:1";
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Output a = ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2});
  Output matmul_source =
      ops::MatMul(root.WithOpName("MatMulSource_dev0"), a, a);

  Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), matmul_source,
                               matmul_source);
  Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), matmul_source,
                               matmul_source);

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  for (Node* n : graph->nodes()) {
    if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
      n->set_assigned_device_name(string(xla_gpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
      n->set_assigned_device_name(string(xla_gpu_dev1));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["A_dev0"], clusters["MatMulSource_dev0"]);
  EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
  EXPECT_NE(clusters["MatMulSource_dev0"], clusters["MatMul1_dev1"]);

  // Improved heuristics should probably prevent this.
  EXPECT_EQ(clusters["MatMulSource_dev0"], clusters["MatMul0_dev0"]);
}

TEST(XlaCompilationTest, ClusterStatefulRandomOpOnXlaDevice) {
  absl::string_view xla_cpu_device =
      "/job:worker/replica:0/task:0/device:XLA_CPU:0";

  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200});
  Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT);
  Output b = ops::RandomUniform(root.WithOpName("test/b"), shape, DT_FLOAT);
  Output c = ops::Add(root.WithOpName("test/c"), a, b);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_cpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/a"], "");
  EXPECT_NE(clusters["test/b"], "");
  EXPECT_NE(clusters["test/c"], "");
}

TEST(XlaCompilationTest, DontAutoClusterStatefulRandomOp) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200});
  Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT);
  Output b = ops::RandomUniform(root.WithOpName("test/b"), shape, DT_FLOAT);
  Output c = ops::Add(root.WithOpName("test/c"), a, b);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/a"], "");
  EXPECT_EQ(clusters["test/b"], "");
}

TEST(XlaCompilationTest, ClusterDummyOpsOnXlaDevice) {
  absl::string_view xla_cpu_device =
      "/job:worker/replica:0/task:0/device:XLA_CPU:0";

  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
  Output check =
      ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check");
  Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b);
  Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b});

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_cpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/check"], "");
  EXPECT_NE(clusters["test/greaterequal"], "");
  EXPECT_NE(clusters["test/assert"], "");
}

TEST(XlaCompilationTest, DontAutoClusterDummyOps) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
  Output check =
      ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check");
  Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b);
  Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b});

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/assert"], "");
  EXPECT_EQ(clusters["test/check"], "");
}

TEST(XlaCompilationTest, DontAutoClusterOpsProducingVariant) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64);

  Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32);
  Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32);

  Output tensor_list_reserve = ops::TensorListReserve(
      root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/tensor_list_reserve"], "");
}

TEST(XlaCompilationTest, DontAutoClusterOpsConsumingVariant) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output dummy_input =
      ops::Placeholder(root.WithOpName("test/dummy_input"), DT_INT64);
  Output variant_input =
ops::Placeholder(root.WithOpName("test/variant_input"), DT_VARIANT); 1141 1142 // Create one more node so that we don't avoid creating a cluster solely 1143 // because it would be trivial. 1144 Output dummy_cast = 1145 ops::Cast(root.WithOpName("test/dummy_cast"), dummy_input, DT_INT32); 1146 1147 Output tensor_list_element_shape = ops::TensorListElementShape( 1148 root.WithOpName("test/tensor_list_element_shape"), variant_input, 1149 DT_INT32); 1150 1151 root.graph()->AddControlEdge(dummy_cast.node(), 1152 tensor_list_element_shape.node()); 1153 1154 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); 1155 TF_ASSERT_OK(root.ToGraph(graph.get())); 1156 1157 TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); 1158 1159 std::unordered_map<string, string> clusters = GetClusters(*graph); 1160 EXPECT_EQ(clusters["test/tensor_list_element_shape"], ""); 1161 } 1162 1163 TEST(XlaCompilationTest, ClusterOpsProducingVariantIfOnXlaDevice) { 1164 Scope root = Scope::NewRootScope().ExitOnError(); 1165 Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64); 1166 Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64); 1167 1168 Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32); 1169 Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32); 1170 1171 Output tensor_list_reserve = ops::TensorListReserve( 1172 root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT); 1173 1174 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); 1175 TF_ASSERT_OK(root.ToGraph(graph.get())); 1176 1177 string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; 1178 for (Node* n : graph->nodes()) { 1179 if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { 1180 n->set_assigned_device_name(xla_cpu_device); 1181 } 1182 } 1183 1184 TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); 1185 1186 std::unordered_map<string, string> clusters = GetClusters(*graph); 1187 EXPECT_NE(clusters["test/tensor_list_reserve"], ""); 1188 } 1189 1190 const char* kCPU0 = "/job:worker/replica:0/task:0/device:CPU:0"; 1191 const char* kGPU0 = "/job:worker/replica:0/task:0/device:GPU:0"; 1192 const char* kXLA_GPU0 = "/job:worker/replica:0/task:0/device:XLA_GPU:0"; 1193 const char* kGPU1 = "/job:worker/replica:0/task:0/device:GPU:1"; 1194 1195 TEST(XlaCompilationTest, CreateCombinedCpuGpuClusters) { 1196 Scope root = Scope::NewRootScope().ExitOnError(); 1197 Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT); 1198 Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT); 1199 1200 Output x = ops::Add(root.WithOpName("test/x"), a, b); 1201 Output y = ops::MatMul(root.WithOpName("test/y"), a, b); 1202 Output z = ops::Add(root.WithOpName("test/z"), x, y); 1203 1204 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); 1205 TF_ASSERT_OK(root.ToGraph(graph.get())); 1206 1207 FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0); 1208 FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kCPU0); 1209 FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0); 1210 1211 TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); 1212 1213 std::unordered_map<string, string> clusters = GetClusters(*graph); 1214 1215 EXPECT_NE(clusters["test/x"], ""); 1216 1217 EXPECT_EQ(clusters["test/x"], clusters["test/y"]); 1218 EXPECT_EQ(clusters["test/y"], clusters["test/z"]); 1219 } 1220 1221 TEST(XlaCompilationTest, 
     DontCreateGpu0AndGpu1Clusters) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::Add(root.WithOpName("test/y"), x, x);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU1);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/x"], "");
  EXPECT_EQ(clusters["test/y"], "");
}

TEST(XlaCompilationTest, DontCreateCombinedCpuUnknownClusters) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::Add(root.WithOpName("test/y"), x, x);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kXLA_GPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/x"], "");
  EXPECT_EQ(clusters["test/y"], "");
}

TEST(XlaCompilationTest, ClusterResourceOpsWhenSafe) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Node* var_handle;
  Node* resource_read = MakeRead(root, "read", &var_handle);
  Output b = ops::Add(root.WithOpName("test/b"), Output(resource_read, 0), a);

  string resource_read_name = resource_read->name();
  string var_handle_name = var_handle->name();

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/b")->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), resource_read_name)
      ->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), var_handle_name)->set_assigned_device_name(kGPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_NE(clusters["test/b"], "");
  EXPECT_EQ(clusters["test/b"], clusters[resource_read_name]);
}

TEST(XlaCompilationTest, DontClusterResourceOpsWhenUnsafe) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Node* var_handle;
  Node* resource_read = MakeRead(root, "read", &var_handle);
  Output b = ops::Add(root.WithOpName("test/b"), Output(resource_read, 0), a);

  string resource_read_name = resource_read->name();
  string var_handle_name = var_handle->name();

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

FindNodeByName(graph.get(), "test/b")->set_assigned_device_name(kGPU0); 1305 FindNodeByName(graph.get(), resource_read_name) 1306 ->set_assigned_device_name(kCPU0); 1307 FindNodeByName(graph.get(), var_handle_name)->set_assigned_device_name(kCPU0); 1308 1309 TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); 1310 1311 std::unordered_map<string, string> clusters = GetClusters(*graph); 1312 1313 EXPECT_EQ(clusters["test/b"], ""); 1314 EXPECT_EQ(clusters[resource_read_name], ""); 1315 } 1316 1317 } // namespace 1318 } // namespace tensorflow 1319