1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/call_graph.h" 17 18 #include "tensorflow/compiler/xla/literal_util.h" 19 #include "tensorflow/compiler/xla/service/hlo_computation.h" 20 #include "tensorflow/compiler/xla/shape_util.h" 21 #include "tensorflow/compiler/xla/status_macros.h" 22 #include "tensorflow/compiler/xla/test.h" 23 #include "tensorflow/compiler/xla/test_helpers.h" 24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 25 #include "tensorflow/compiler/xla/util.h" 26 #include "tensorflow/compiler/xla/xla_data.pb.h" 27 #include "tensorflow/core/lib/core/status_test_util.h" 28 29 namespace xla { 30 namespace { 31 32 using ::testing::UnorderedElementsAre; 33 34 class CallGraphTest : public HloTestBase { 35 protected: 36 // Build and return a trivial computation taking and returning a scalar. 37 std::unique_ptr<HloComputation> MakeScalarComputation( 38 HloOpcode opcode = HloOpcode::kNegate) { 39 HloComputation::Builder builder(TestName() + ".ScalarComputation"); 40 HloInstruction* param0 = builder.AddInstruction( 41 HloInstruction::CreateParameter(0, kScalarShape, "param0")); 42 builder.AddInstruction( 43 HloInstruction::CreateUnary(kScalarShape, opcode, param0)); 44 return builder.Build(); 45 } 46 47 // Build and return a computation which takes a scalar and maps (kMap) the 48 // given computation to the value 'callsites' number of times. 49 std::unique_ptr<HloComputation> MakeMappingComputation( 50 HloComputation* map_computation, int64 callsites) { 51 HloComputation::Builder builder(TestName() + ".MappingComputation"); 52 HloInstruction* param0 = builder.AddInstruction( 53 HloInstruction::CreateParameter(0, kScalarShape, "param0")); 54 HloInstruction* last_value = param0; 55 for (int64 i = 0; i < callsites; ++i) { 56 last_value = builder.AddInstruction(HloInstruction::CreateMap( 57 kScalarShape, {last_value}, map_computation)); 58 } 59 return builder.Build(); 60 } 61 62 // Build and return a computation which takes a scalar and calls (kCall) the 63 // given computation with value 'callsites' number of times. 64 std::unique_ptr<HloComputation> MakeCallingComputation( 65 HloComputation* callee_computation, int64 callsites, 66 const string& suffix = ".CallingComputation") { 67 HloComputation::Builder builder(TestName() + suffix); 68 HloInstruction* param0 = builder.AddInstruction( 69 HloInstruction::CreateParameter(0, kScalarShape, "param0")); 70 HloInstruction* last_value = param0; 71 for (int64 i = 0; i < callsites; ++i) { 72 last_value = builder.AddInstruction(HloInstruction::CreateCall( 73 kScalarShape, {last_value}, callee_computation)); 74 } 75 return builder.Build(); 76 } 77 78 // Build and return a computation which takes a scalar and returns a PRED 79 // value. 80 std::unique_ptr<HloComputation> MakeConditionComputation() { 81 HloComputation::Builder builder(TestName() + ".ConditionComputation"); 82 HloInstruction* param0 = builder.AddInstruction( 83 HloInstruction::CreateParameter(0, kScalarShape, "param0")); 84 HloInstruction* zero = builder.AddInstruction( 85 HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))); 86 builder.AddInstruction(HloInstruction::CreateBinary( 87 ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero)); 88 return builder.Build(); 89 } 90 91 const Shape kScalarShape = ShapeUtil::MakeShape(F32, {}); 92 }; 93 94 TEST_F(CallGraphTest, SingletonComputation) { 95 // Test the call graph of a module with a single computation. 96 auto module = CreateNewModule(); 97 HloComputation* computation = 98 module->AddEntryComputation(MakeScalarComputation()); 99 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); 100 EXPECT_EQ(1, call_graph->nodes().size()); 101 EXPECT_TRUE(call_graph->IsFlattened()); 102 103 const CallGraphNode& node = call_graph->GetNode(computation); 104 EXPECT_EQ(computation, node.computation()); 105 EXPECT_TRUE(node.callsites().empty()); 106 EXPECT_TRUE(node.callees().empty()); 107 EXPECT_TRUE(node.caller_callsites().empty()); 108 EXPECT_TRUE(node.callers().empty()); 109 EXPECT_EQ(CallContext::kSequential, node.context()); 110 } 111 112 TEST_F(CallGraphTest, UnreachableComputation) { 113 // Test the call graph of a module with an entry computation and an 114 // unreachable computation. 115 auto module = CreateNewModule(); 116 HloComputation* entry_computation = 117 module->AddEntryComputation(MakeScalarComputation()); 118 HloComputation* unreachable_computation = 119 module->AddEmbeddedComputation(MakeScalarComputation()); 120 121 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); 122 EXPECT_EQ(2, call_graph->nodes().size()); 123 124 const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); 125 EXPECT_EQ(entry_computation, entry_node.computation()); 126 EXPECT_EQ(CallContext::kSequential, entry_node.context()); 127 128 const CallGraphNode& unreachable_node = 129 call_graph->GetNode(unreachable_computation); 130 EXPECT_EQ(unreachable_computation, unreachable_node.computation()); 131 EXPECT_EQ(CallContext::kSequential, unreachable_node.context()); 132 } 133 134 TEST_F(CallGraphTest, ParallelComputation) { 135 // Test a call graph of a module with an entry computation which calls another 136 // computation in a parallel context via kMap. 137 auto module = CreateNewModule(); 138 HloComputation* map_computation = 139 module->AddEmbeddedComputation(MakeScalarComputation()); 140 HloComputation* entry_computation = module->AddEntryComputation( 141 MakeMappingComputation(map_computation, /*callsites=*/5)); 142 143 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); 144 EXPECT_EQ(2, call_graph->nodes().size()); 145 146 const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); 147 EXPECT_EQ(entry_computation, entry_node.computation()); 148 EXPECT_EQ(CallContext::kSequential, entry_node.context()); 149 EXPECT_EQ(5, entry_node.callsites().size()); 150 EXPECT_EQ(1, entry_node.callees().size()); 151 EXPECT_TRUE(entry_node.caller_callsites().empty()); 152 EXPECT_TRUE(entry_node.callers().empty()); 153 154 const CallGraphNode& map_node = call_graph->GetNode(map_computation); 155 EXPECT_EQ(map_computation, map_node.computation()); 156 EXPECT_EQ(CallContext::kParallel, map_node.context()); 157 EXPECT_TRUE(map_node.callsites().empty()); 158 EXPECT_TRUE(map_node.callees().empty()); 159 EXPECT_EQ(5, map_node.caller_callsites().size()); 160 EXPECT_EQ(1, map_node.callers().size()); 161 } 162 163 TEST_F(CallGraphTest, SequentialComputations) { 164 // Test a call graph of a module with an entry computation which calls another 165 // computation in a sequential context via kCall. 166 auto module = CreateNewModule(); 167 HloComputation* called_computation = 168 module->AddEmbeddedComputation(MakeScalarComputation()); 169 HloComputation* entry_computation = module->AddEntryComputation( 170 MakeCallingComputation(called_computation, /*callsites=*/3)); 171 172 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); 173 EXPECT_EQ(2, call_graph->nodes().size()); 174 175 // The called computation is only called from one other computation, but there 176 // are multiple callsites. 177 EXPECT_FALSE(call_graph->IsFlattened()); 178 179 const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); 180 EXPECT_EQ(entry_computation, entry_node.computation()); 181 EXPECT_EQ(CallContext::kSequential, entry_node.context()); 182 EXPECT_EQ(3, entry_node.callsites().size()); 183 EXPECT_EQ(1, entry_node.callees().size()); 184 EXPECT_TRUE(entry_node.caller_callsites().empty()); 185 EXPECT_TRUE(entry_node.callers().empty()); 186 187 const CallGraphNode& called_node = call_graph->GetNode(called_computation); 188 EXPECT_EQ(called_computation, called_node.computation()); 189 EXPECT_EQ(CallContext::kSequential, called_node.context()); 190 EXPECT_TRUE(called_node.callsites().empty()); 191 EXPECT_TRUE(called_node.callees().empty()); 192 EXPECT_EQ(3, called_node.caller_callsites().size()); 193 EXPECT_EQ(1, called_node.callers().size()); 194 } 195 196 TEST_F(CallGraphTest, ContextBothComputations) { 197 // Test a call graph of a module with an entry computation which calls another 198 // computation in both a parallel and sequential context. 199 auto module = CreateNewModule(); 200 HloComputation* subcomputation = 201 module->AddEmbeddedComputation(MakeScalarComputation()); 202 203 HloComputation::Builder builder(TestName()); 204 HloInstruction* param0 = builder.AddInstruction( 205 HloInstruction::CreateParameter(0, kScalarShape, "param0")); 206 HloInstruction* call = builder.AddInstruction( 207 HloInstruction::CreateCall(kScalarShape, {param0}, subcomputation)); 208 HloInstruction* map = builder.AddInstruction( 209 HloInstruction::CreateMap(kScalarShape, {call}, subcomputation)); 210 HloComputation* entry_computation = 211 module->AddEntryComputation(builder.Build()); 212 213 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); 214 EXPECT_EQ(2, call_graph->nodes().size()); 215 216 EXPECT_FALSE(call_graph->IsFlattened()); 217 218 const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); 219 EXPECT_EQ(entry_computation, entry_node.computation()); 220 EXPECT_EQ(2, entry_node.callsites().size()); 221 222 const CallSite& call_callsite = entry_node.callsites()[0]; 223 EXPECT_EQ(call, call_callsite.instruction()); 224 EXPECT_THAT(call_callsite.called_computations(), 225 UnorderedElementsAre(subcomputation)); 226 EXPECT_EQ(CallContext::kSequential, call_callsite.context()); 227 EXPECT_EQ(entry_node.GetCallSite(call), &call_callsite); 228 229 const CallSite& map_callsite = entry_node.callsites()[1]; 230 EXPECT_EQ(map, map_callsite.instruction()); 231 EXPECT_THAT(map_callsite.called_computations(), 232 UnorderedElementsAre(subcomputation)); 233 EXPECT_EQ(CallContext::kParallel, map_callsite.context()); 234 EXPECT_EQ(entry_node.GetCallSite(map), &map_callsite); 235 236 const CallGraphNode& sub_node = call_graph->GetNode(subcomputation); 237 EXPECT_EQ(CallContext::kBoth, sub_node.context()); 238 } 239 240 TEST_F(CallGraphTest, ComputationWithConditional) { 241 // Test a call graph of a module with a conditional. 242 auto module = CreateNewModule(); 243 HloComputation* true_computation = 244 module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kCeil)); 245 HloComputation* false_computation = 246 module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kFloor)); 247 248 HloComputation::Builder builder(TestName()); 249 HloInstruction* pred = builder.AddInstruction( 250 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 251 HloInstruction* const1 = builder.AddInstruction( 252 HloInstruction::CreateConstant(Literal::CreateR0<float>(56.4f))); 253 HloInstruction* const2 = builder.AddInstruction( 254 HloInstruction::CreateConstant(Literal::CreateR0<float>(12.6f))); 255 HloInstruction* conditional = 256 builder.AddInstruction(HloInstruction::CreateConditional( 257 kScalarShape, pred, const1, true_computation, const2, 258 false_computation)); 259 HloComputation* entry_computation = 260 module->AddEntryComputation(builder.Build()); 261 262 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); 263 264 EXPECT_EQ(3, call_graph->nodes().size()); 265 266 const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); 267 EXPECT_EQ(entry_computation, entry_node.computation()); 268 EXPECT_EQ(1, entry_node.callsites().size()); 269 270 const CallSite& conditional_callsite = entry_node.callsites()[0]; 271 EXPECT_EQ(conditional, conditional_callsite.instruction()); 272 EXPECT_THAT(conditional_callsite.called_computations(), 273 UnorderedElementsAre(true_computation, false_computation)); 274 EXPECT_EQ(CallContext::kSequential, conditional_callsite.context()); 275 EXPECT_EQ(entry_node.GetCallSite(conditional), &conditional_callsite); 276 277 const CallGraphNode& true_node = call_graph->GetNode(true_computation); 278 EXPECT_TRUE(true_node.callees().empty()); 279 EXPECT_EQ(1, true_node.callers().size()); 280 EXPECT_EQ(entry_computation, true_node.callers()[0]); 281 282 const CallGraphNode& false_node = call_graph->GetNode(false_computation); 283 EXPECT_TRUE(false_node.callees().empty()); 284 EXPECT_EQ(1, false_node.callers().size()); 285 EXPECT_EQ(entry_computation, false_node.callers()[0]); 286 } 287 288 TEST_F(CallGraphTest, ComplexGraph) { 289 // Test a call graph of a module with several computation called in various 290 // contexts. The call graph looks like: 291 // 292 // entry 293 // / | 294 // a | 295 // / | \ | 296 // b | cond 297 // \ | 298 // c 299 // 300 // Calls are made via kCall, kWhile, and kMap instructions. 301 auto module = CreateNewModule(); 302 HloComputation* cond_computation = 303 module->AddEmbeddedComputation(MakeConditionComputation()); 304 HloComputation* c_computation = 305 module->AddEmbeddedComputation(MakeScalarComputation()); 306 HloComputation* b_computation = module->AddEmbeddedComputation( 307 MakeMappingComputation(c_computation, /*callsites=*/1)); 308 309 HloComputation* a_computation; 310 { 311 HloComputation::Builder builder(TestName() + ".a"); 312 HloInstruction* param0 = builder.AddInstruction( 313 HloInstruction::CreateParameter(0, kScalarShape, "param0")); 314 HloInstruction* call = builder.AddInstruction( 315 HloInstruction::CreateCall(kScalarShape, {param0}, c_computation)); 316 builder.AddInstruction(HloInstruction::CreateWhile( 317 kScalarShape, cond_computation, b_computation, call)); 318 a_computation = module->AddEmbeddedComputation(builder.Build()); 319 } 320 321 HloComputation* entry_computation; 322 { 323 HloComputation::Builder builder(TestName() + ".entry"); 324 HloInstruction* param0 = builder.AddInstruction( 325 HloInstruction::CreateParameter(0, kScalarShape, "param0")); 326 builder.AddInstruction(HloInstruction::CreateWhile( 327 kScalarShape, cond_computation, a_computation, param0)); 328 entry_computation = module->AddEntryComputation(builder.Build()); 329 } 330 331 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); 332 EXPECT_EQ(5, call_graph->nodes().size()); 333 EXPECT_FALSE(call_graph->IsFlattened()); 334 335 // Entry computation has one while instruction calling two computations 336 // (cond_computation and a_computation). 337 const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); 338 ASSERT_EQ(1, entry_node.callsites().size()); 339 const std::vector<HloComputation*>& called_computations = 340 entry_node.callsites()[0].called_computations(); 341 EXPECT_THAT(called_computations, 342 UnorderedElementsAre(cond_computation, a_computation)); 343 EXPECT_EQ(CallContext::kSequential, entry_node.context()); 344 345 const CallGraphNode& c_node = call_graph->GetNode(c_computation); 346 EXPECT_TRUE(c_node.callsites().empty()); 347 EXPECT_THAT(c_node.callers(), 348 UnorderedElementsAre(a_computation, b_computation)); 349 EXPECT_EQ(CallContext::kBoth, c_node.context()); 350 351 // Visit the graph and verify nodes were visited in callee-before-caller 352 // order. 353 std::vector<const HloComputation*> visited; 354 TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) { 355 visited.push_back(node.computation()); 356 return Status::OK(); 357 })); 358 EXPECT_EQ(visited.size(), 5); 359 // All values in visited should be unique. 360 EXPECT_EQ( 361 std::unordered_set<const HloComputation*>(visited.begin(), visited.end()) 362 .size(), 363 5); 364 365 // Verify visitation order of some computations in the graph. 366 auto index_of = [&visited](const HloComputation* comp) { 367 auto it = std::find(visited.begin(), visited.end(), comp); 368 EXPECT_NE(it, visited.end()); 369 return std::distance(visited.begin(), it); 370 }; 371 EXPECT_EQ(4, index_of(entry_computation)); 372 EXPECT_LT(index_of(cond_computation), index_of(a_computation)); 373 EXPECT_LT(index_of(c_computation), index_of(b_computation)); 374 EXPECT_LT(index_of(b_computation), index_of(a_computation)); 375 376 // Verify dominance relations between computation in the graph. 377 378 // Entry dominates everybody, and is dominated by no one except itself. 379 EXPECT_TRUE(call_graph->Dominates(entry_computation, entry_computation)); 380 EXPECT_TRUE(call_graph->Dominates(entry_computation, a_computation)); 381 EXPECT_TRUE(call_graph->Dominates(entry_computation, b_computation)); 382 EXPECT_TRUE(call_graph->Dominates(entry_computation, c_computation)); 383 EXPECT_TRUE(call_graph->Dominates(entry_computation, cond_computation)); 384 EXPECT_FALSE(call_graph->Dominates(a_computation, entry_computation)); 385 EXPECT_FALSE(call_graph->Dominates(b_computation, entry_computation)); 386 EXPECT_FALSE(call_graph->Dominates(c_computation, entry_computation)); 387 EXPECT_FALSE(call_graph->Dominates(cond_computation, entry_computation)); 388 389 // 'a' only dominates 'b' and 'c'. 390 EXPECT_TRUE(call_graph->Dominates(a_computation, a_computation)); 391 EXPECT_TRUE(call_graph->Dominates(a_computation, b_computation)); 392 EXPECT_TRUE(call_graph->Dominates(a_computation, c_computation)); 393 EXPECT_FALSE(call_graph->Dominates(b_computation, a_computation)); 394 EXPECT_FALSE(call_graph->Dominates(c_computation, a_computation)); 395 EXPECT_FALSE(call_graph->Dominates(a_computation, cond_computation)); 396 397 EXPECT_TRUE(call_graph->Dominates(b_computation, b_computation)); 398 EXPECT_FALSE(call_graph->Dominates(b_computation, c_computation)); 399 EXPECT_FALSE(call_graph->Dominates(b_computation, cond_computation)); 400 401 EXPECT_TRUE(call_graph->Dominates(c_computation, c_computation)); 402 EXPECT_FALSE(call_graph->Dominates(c_computation, cond_computation)); 403 EXPECT_FALSE(call_graph->Dominates(cond_computation, c_computation)); 404 405 EXPECT_TRUE(call_graph->Dominates(cond_computation, cond_computation)); 406 } 407 408 TEST_F(CallGraphTest, ComplexGraphNearestAncestors) { 409 // Test NearestAncestorsInSameComputation on a call graph of a module with 410 // several computation called in various contexts. The call graph looks like: 411 // 412 // entry 413 // / | 414 // a | 415 // / | \ | 416 // b | cond 417 // \ | 418 // c 419 // 420 // Calls are made via kCall, kWhile, and kMap instructions. 421 auto module = CreateNewModule(); 422 HloComputation* cond_computation = 423 module->AddEmbeddedComputation(MakeConditionComputation()); 424 HloComputation* c_computation = 425 module->AddEmbeddedComputation(MakeScalarComputation()); 426 HloComputation* b_computation = module->AddEmbeddedComputation( 427 MakeMappingComputation(c_computation, /*callsites=*/1)); 428 HloInstruction* b_map = b_computation->root_instruction(); 429 430 HloComputation* a_computation; 431 HloInstruction* a_call; 432 HloInstruction* a_while; 433 { 434 HloComputation::Builder builder(TestName() + ".a"); 435 HloInstruction* param0 = builder.AddInstruction( 436 HloInstruction::CreateParameter(0, kScalarShape, "param0")); 437 a_call = builder.AddInstruction( 438 HloInstruction::CreateCall(kScalarShape, {param0}, c_computation)); 439 a_while = builder.AddInstruction(HloInstruction::CreateWhile( 440 kScalarShape, cond_computation, b_computation, a_call)); 441 a_computation = module->AddEmbeddedComputation(builder.Build()); 442 } 443 444 HloComputation* entry_computation; 445 HloInstruction* entry_while; 446 { 447 HloComputation::Builder builder(TestName() + ".entry"); 448 HloInstruction* param0 = builder.AddInstruction( 449 HloInstruction::CreateParameter(0, kScalarShape, "param0")); 450 entry_while = builder.AddInstruction(HloInstruction::CreateWhile( 451 kScalarShape, cond_computation, a_computation, param0)); 452 entry_computation = module->AddEntryComputation(builder.Build()); 453 } 454 455 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); 456 EXPECT_EQ(5, call_graph->nodes().size()); 457 458 // Verify NearestAncestorsInSameComputation for various instructions in the 459 // module. 460 EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(a_call, a_call), 461 std::make_pair(a_call, a_call)); 462 463 // c_computation is called from more than one site, so 464 // NearestAncestorsInSameComputation bails and returns nullptrs. 465 std::pair<HloInstruction*, HloInstruction*> null_pair = {nullptr, nullptr}; 466 EXPECT_EQ(call_graph->NearestAncestorsInSameComputation( 467 b_map, c_computation->root_instruction()), 468 null_pair); 469 470 EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(b_map, entry_while), 471 std::make_pair(entry_while, entry_while)); 472 EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(b_map, a_call), 473 std::make_pair(a_while, a_call)); 474 EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(a_while, a_call), 475 std::make_pair(a_while, a_call)); 476 EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(a_while, b_map), 477 std::make_pair(a_while, a_while)); 478 } 479 480 TEST_F(CallGraphTest, VisitSingletonComputation) { 481 // Test the call graph visitor with a call graph with a single node. 482 auto module = CreateNewModule(); 483 HloComputation* computation = 484 module->AddEntryComputation(MakeScalarComputation()); 485 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); 486 487 std::vector<HloComputation*> visited; 488 TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) { 489 visited.push_back(node.computation()); 490 return Status::OK(); 491 })); 492 EXPECT_THAT(visited, UnorderedElementsAre(computation)); 493 } 494 495 TEST_F(CallGraphTest, VisitUnreachableComputation) { 496 // Test the call graph visitor with a call graph with an unreachable node. 497 auto module = CreateNewModule(); 498 HloComputation* entry_computation = 499 module->AddEntryComputation(MakeScalarComputation()); 500 HloComputation* unreachable_computation = 501 module->AddEmbeddedComputation(MakeScalarComputation()); 502 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); 503 504 // Test visitation of only reachable nodes. 505 { 506 std::vector<const HloComputation*> visited; 507 TF_ASSERT_OK(call_graph->VisitNodes( 508 [&visited](const CallGraphNode& node) { 509 visited.push_back(node.computation()); 510 return Status::OK(); 511 }, 512 /*visit_unreachable_nodes=*/false)); 513 EXPECT_EQ(visited.size(), 1); 514 EXPECT_EQ(visited[0], entry_computation); 515 } 516 517 // Test visitation of all nodes (reachable and unreachable). 518 { 519 std::vector<HloComputation*> visited; 520 TF_ASSERT_OK(call_graph->VisitNodes( 521 [&visited](const CallGraphNode& node) { 522 visited.push_back(node.computation()); 523 return Status::OK(); 524 }, 525 /*visit_unreachable_nodes=*/true)); 526 EXPECT_EQ(visited.size(), 2); 527 EXPECT_THAT(visited, UnorderedElementsAre(entry_computation, 528 unreachable_computation)); 529 } 530 } 531 532 TEST_F(CallGraphTest, VisitWithError) { 533 // Test that the call graph visitor properly propagates errors. 534 auto module = CreateNewModule(); 535 module->AddEntryComputation(MakeScalarComputation()); 536 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); 537 538 Status status = call_graph->VisitNodes( 539 [](const CallGraphNode&) { return InternalError("Visitation failed"); }); 540 541 ASSERT_FALSE(status.ok()); 542 ASSERT_EQ(status.code(), tensorflow::error::INTERNAL); 543 ASSERT_THAT(status.error_message(), 544 ::testing::HasSubstr("Visitation failed")); 545 } 546 547 } // namespace 548 } // namespace xla 549