/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Tests for the XLA auto-clustering ("mark for compilation") graph pass.
// Each test constructs a small Graph, runs MarkForCompilationPass over it,
// and inspects which nodes received an XLA cluster attribute.

#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

// Test-only ops with no XLA kernel registered for them; nodes using these ops
// are therefore not compilable and must be excluded from any cluster.
REGISTER_OP("UncompilableNullary").Output("o: float");
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");

// Assigns every node in *graph to the local CPU device, then runs the
// mark-for-compilation pass over the graph using `flib_def` as the function
// library. Returns the status of the pass.
Status MarkForCompilation(std::unique_ptr<Graph>* graph,
                          FunctionLibraryDefinition* flib_def) {
  // Assign all nodes to the CPU device.
  static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
  for (Node* n : (*graph)->nodes()) {
    n->set_assigned_device_name(kCpuDevice);
  }

  GraphOptimizationPassOptions opt_options;
  opt_options.graph = graph;
  opt_options.flib_def = flib_def;
  MarkForCompilationPass pass;
  return pass.RunImpl(opt_options);
}

// Convenience overload: runs the pass with an empty function library.
Status MarkForCompilation(std::unique_ptr<Graph>* graph) {
  FunctionDefLibrary flib;
  FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
  return MarkForCompilation(graph, &flib_def);
}

// Returns a map from node name to XLA cluster name for every node in `graph`
// that carries the kXlaClusterAttr attribute. Nodes the pass did not mark are
// absent from the map.
std::unordered_map<string, string> GetClusters(const Graph& graph) {
  std::unordered_map<string, string> ids;
  for (Node* node : graph.nodes()) {
    string cluster;
    if (GetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster).ok()) {
      CHECK(!cluster.empty());
      ids[node->name()] = cluster;
    }
  }
  return ids;
}

// A chain A->B->C->D->E->F where A and D are uncompilable should be split
// into two separate clusters: {B, C} and {E, F}.
TEST(XlaCompilationTest, Chains) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
    Node* d =
        ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D"));
    Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
    ops::UnaryOp("Relu", e, builder.opts().WithName("F"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_EQ(4, clusters.size());
  EXPECT_EQ(clusters["B"], clusters["C"]);
  EXPECT_EQ(clusters["E"], clusters["F"]);
  EXPECT_NE(clusters["B"], clusters["E"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
}

// C = MatMul(A, UncompilableUnary(A)). A and C are individually compilable,
// but clustering them around the excluded node B would create a cycle between
// the cluster and B, so nothing may be clustered.
TEST(XlaCompilationTest, UncompilableCycles) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b =
        ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}

// Same diamond shape as above, but with all ops compilable: the whole graph
// forms a single cluster.
TEST(XlaCompilationTest, CompilableCycles) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(3, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["A"], clusters["C"]);
}

// Ops whose dtype (here DT_COMPLEX128) is not supported for compilation must
// not be clustered.
TEST(XlaCompilationTest, UnsupportedTypes) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp(
        "Const", builder.opts()
                     .WithName("A")
                     .WithAttr("dtype", DT_COMPLEX128)
                     .WithAttr("value", Tensor(DT_COMPLEX128, TensorShape())));
    Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_TRUE(clusters.empty());
}

// A Concat whose axis argument is a constant should be fully compilable,
// including the constant inputs.
TEST(XlaCompilationTest, ConcatWithConstArg) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    Tensor t(DT_INT32, TensorShape());
    t.scalar<int32>()() = 0;
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* dim = ops::SourceOp("Const", builder.opts()
                                           .WithName("Dim")
                                           .WithAttr("dtype", DT_INT32)
                                           .WithAttr("value", t));
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", t));

    NodeBuilder concat_builder("Concat", "Concat",
                               builder.opts().op_registry());
    concat_builder.Input(dim).Input({a, a}).Attr("N", 2);
    builder.opts().FinalizeBuilder(&concat_builder);

    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_EQ(3, clusters.size());  // Everything should be compiled.
}

// Function calls: an inlinable compilable function (B) clusters with its
// neighbors; an uncompilable function (D) and a _noinline function (E) do not.
TEST(XlaCompilationTest, FunctionCalls) {
  FunctionDef compilable = FunctionDefHelper::Define(
      "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
      {{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}});
  FunctionDef uncompilable =
      FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"},
                                {}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}});
  // NoInlineFn is behaviorally identical to CompilableFn but carries the
  // _noinline attr, which prevents it from being compiled.
  FunctionDef noinline = compilable;
  noinline.mutable_signature()->set_name("NoInlineFn");
  AddAttr("_noinline", bool(true), noinline.mutable_attr());

  FunctionDefLibrary flib;
  *flib.add_function() = compilable;
  *flib.add_function() = uncompilable;
  *flib.add_function() = noinline;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

  std::unique_ptr<Graph> graph(new Graph(&flib_def));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    Node* b = ops::BinaryOp("CompilableFn", a, a, builder.opts().WithName("B"));
    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
    ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D"));
    ops::BinaryOp("NoInlineFn", c, c, builder.opts().WithName("E"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_FALSE(clusters["B"].empty());
  EXPECT_EQ(clusters["B"], clusters["C"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
  EXPECT_TRUE(clusters.find("E") == clusters.cend());
}

// Metadata-only operators such as Shape/Rank/Size may not be the root of a
// cluster. This is partially to work around b/26800664, and partially because
// we should probably prefer to compile metadata operators with their producers
// wherever possible, rather than their consumers.
TEST(XlaCompilationTest, MetadataOpsDontStartClusters) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    // While all of the following ops are notionally compilable, none is
    // permitted to start a cluster. So nothing should be compiled.
    Node* b = ops::UnaryOp("Shape", a, builder.opts().WithName("B"));
    Node* c = ops::UnaryOp("Rank", b, builder.opts().WithName("C"));
    Node* d = ops::UnaryOp("Size", c, builder.opts().WithName("D"));
    ops::UnaryOp("Shape", d, builder.opts().WithName("E"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }
  TF_ASSERT_OK(MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_EQ(0, clusters.size());  // Nothing should be compiled.
}

// Builds a FunctionDef for a unary-cwise gradient from `nodes`, defaulting
// each node's "T" attr to DT_FLOAT when unset.
static Status GradForUnaryCwise(FunctionDef* g,
                                std::vector<FunctionDefHelper::Node> nodes) {
  for (auto& n : nodes) {
    if (n.attr.empty()) {
      n.attr = {{"T", DT_FLOAT}};
    }
  }
  *g = FunctionDefHelper::Define(
      // Arg defs
      {"x: float", "dy: float"},
      // Ret val defs
      {"dx: float"},
      // Attr defs
      {},
      // Nodes
      nodes);
  return Status::OK();
}

// A gradient containing only supported operators
Status SupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Tanh", {"x"}},
      {{"y2"}, "Square", {"y"}, {}, {"dy"}},
      FunctionDefHelper::Const("one", 1.0f),
      {{"a"}, "Sub", {"one", "y2"}},
      {{"dx"}, "Mul", {"dy", "a"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Supported", SupportedGrad);

// A gradient containing an unsupported operator.
Status UnsupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Tanh", {"x"}},
      {{"y2"}, "UncompilableUnary", {"y"}, {}, {"dy"}},
      FunctionDefHelper::Const("one", 1.0f),
      {{"a"}, "Sub", {"one", "y2"}},
      {{"dx"}, "Mul", {"dy", "a"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Unsupported", UnsupportedGrad);

// SymbolicGradient nodes are compilable only when the gradient function they
// reference is itself compilable: B ("Supported") clusters with C, while D
// ("Unsupported") stays out.
TEST(XlaCompilationTest, SymbolicGradients) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));

    // Builds a Symbolic gradient for Supported
    NodeBuilder b_builder("B", "SymbolicGradient",
                          builder.opts().op_registry());
    NameAttrList b_name_attr;
    b_name_attr.set_name("Supported");
    b_builder.Attr("f", b_name_attr);
    b_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
    b_builder.Attr("Tout", {DT_FLOAT});
    b_builder.Input({a, a});
    Node* b = builder.opts().FinalizeBuilder(&b_builder);

    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));

    // Builds a Symbolic gradient for Unsupported
    NodeBuilder d_builder("D", "SymbolicGradient",
                          builder.opts().op_registry());
    NameAttrList d_name_attr;
    d_name_attr.set_name("Unsupported");
    d_builder.Attr("f", d_name_attr);
    d_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
    d_builder.Attr("Tout", {DT_FLOAT});
    d_builder.Input({c, c});
    builder.opts().FinalizeBuilder(&d_builder);

    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_FALSE(clusters["B"].empty());
  EXPECT_EQ(clusters["B"], clusters["C"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
}

TEST(XlaCompilationTest, Loops) {
  // Regression test for b/32350199, where the autoclustering code introduced a
  // deadlock in a graph containing a while loop.
  Scope root = Scope::NewRootScope().ExitOnError();
  auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
  auto b = ops::Placeholder(root.WithOpName("B"), DT_FLOAT);
  auto c = ops::Add(root.WithOpName("C"), a, b);
  auto enter = ops::internal::Enter(root, c, "aframe");
  auto next_iter = ops::NextIteration(root, enter);
  auto exit = ops::internal::Exit(root, next_iter);
  auto d = ops::Add(root.WithOpName("D"), c, exit);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  // Nothing should be compiled. In particular, 'd' and 'c' must not be
  // compiled.
  EXPECT_EQ(0, clusters.size());
}

TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp(
        "Relu", a,
        builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp(
        "MatMul", a, b,
        builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  // The computation is: C = MatMul(A, relu(A))
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
  // Since all three nodes have different scopes, we cannot fuse anything, and
  // there are no clusters.
  EXPECT_EQ(0, clusters.size());
}

TEST(XlaCompilationTest, CyclesWithSplittingScopes) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaScopeAttr, "Scope1"));
    Node* b = ops::UnaryOp(
        "Relu", a,
        builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "Scope1"));
    Node* c = ops::BinaryOp(
        "MatMul", a, b,
        builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "Scope2"));
    ops::BinaryOp(
        "Add", b, c,
        builder.opts().WithName("D").WithAttr(kXlaScopeAttr, "Scope2"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  // The computation is: D = relu(A) + (A @ relu(A))
  // where A and relu(A) are in Scope1, and the @, + ops are in Scope2.
  // In this case, we can fuse the A and relu(A), and we can fuse the
  // second half of the operations; there are two clusters covering all four
  // nodes.
  EXPECT_EQ(4, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_NE(clusters["A"], clusters["C"]);
  EXPECT_EQ(clusters["C"], clusters["D"]);
}

TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp(
        "Relu", a,
        builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  // The computation is: C = MatMul(A, relu(A))
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C has no scope.
  // The EXPECTs below show that scope-less C fuses with B, while A (in a
  // different scope from B) stays out of their cluster.
  EXPECT_EQ(2, clusters.size());
  EXPECT_NE(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["B"], clusters["C"]);
}

// Test-only ops that produce/consume DT_RESOURCE, used to check that resource
// ops are kept out of clusters even when XLA kernels exist for them.
REGISTER_OP("ResourceInput").Input("a: resource").Output("o: float");
REGISTER_OP("ResourceOutput").Input("a: float").Output("o: resource");

namespace {

// No-op XLA kernel used to register the resource ops above for compilation.
class DummyOp : public XlaOpKernel {
  using XlaOpKernel::XlaOpKernel;
  void Compile(XlaOpKernelContext* ctx) override {}
};

REGISTER_XLA_OP(Name("ResourceInput"), DummyOp);
REGISTER_XLA_OP(Name("ResourceOutput"), DummyOp);

}  // namespace

TEST(XlaCompilationTest, Resources) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    // We should not form clusters with resource ops by default.
    Node* c = ops::UnaryOp("ResourceOutput", b, builder.opts().WithName("C"));
    Node* d = ops::UnaryOp("ResourceInput", c, builder.opts().WithName("D"));
    ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }
  TF_ASSERT_OK(MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_EQ(0, clusters.size());  // Nothing should be compiled.
}

// A control-edge cycle a->b->c->a must make the pass fail with an error
// message that draws the offending cycle.
TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Scope root = Scope::NewRootScope().ExitOnError();
  {
    auto BuildNoopNode = [](StringPiece name, Graph* graph) {
      NodeDefBuilder builder(name, "NoOp");
      NodeDef def;
      TF_CHECK_OK(builder.Finalize(&def));

      Status status;
      Node* node = graph->AddNode(def, &status);
      TF_CHECK_OK(status);
      return node;
    };

    Node* a = BuildNoopNode("a", graph.get());
    Node* b = BuildNoopNode("b", graph.get());
    Node* c = BuildNoopNode("c", graph.get());
    graph->AddControlEdge(a, b);
    graph->AddControlEdge(b, c);
    graph->AddControlEdge(c, a);
  }

  TF_EXPECT_OK(root.ToGraph(graph.get()));

  Status status = MarkForCompilation(&graph);
  EXPECT_FALSE(status.ok());
  EXPECT_TRUE(StringPiece(status.ToString())
                  .contains("Edge from c to a would create a cycle.\n"
                            "+-> a\n"
                            "| b\n"
                            "+-- c\n"));
}

// _Retval nodes are never clustered, but their compilable producers are.
TEST(XlaCompilationTest, Retval) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    ops::UnaryOp("_Retval", b,
                 builder.opts()
                     .WithName("R")
                     .WithAttr("T", DT_FLOAT)
                     .WithAttr("index", 0));

    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_TRUE(clusters.find("R") == clusters.cend());
  EXPECT_EQ(clusters["A"], clusters["B"]);
}

}  // namespace
}  // namespace tensorflow