/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/call_graph.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {

using ::testing::UnorderedElementsAre;

class CallGraphTest : public HloTestBase {
 protected:
  // Build and return a trivial computation taking and returning a scalar.
  std::unique_ptr<HloComputation> MakeScalarComputation(
      HloOpcode opcode = HloOpcode::kNegate) {
    HloComputation::Builder builder(TestName() + ".ScalarComputation");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    builder.AddInstruction(
        HloInstruction::CreateUnary(kScalarShape, opcode, param0));
    return builder.Build();
  }

  // Build and return a computation which takes a scalar and applies the given
  // computation to it 'callsites' times via kMap.
  std::unique_ptr<HloComputation> MakeMappingComputation(
      HloComputation* map_computation, int64 callsites) {
    HloComputation::Builder builder(TestName() + ".MappingComputation");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    HloInstruction* last_value = param0;
    for (int64 i = 0; i < callsites; ++i) {
      last_value = builder.AddInstruction(HloInstruction::CreateMap(
          kScalarShape, {last_value}, map_computation));
    }
    return builder.Build();
  }

  // Build and return a computation which takes a scalar and calls the given
  // computation on it 'callsites' times via kCall.
  std::unique_ptr<HloComputation> MakeCallingComputation(
      HloComputation* callee_computation, int64 callsites,
      const string& suffix = ".CallingComputation") {
    HloComputation::Builder builder(TestName() + suffix);
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    HloInstruction* last_value = param0;
    for (int64 i = 0; i < callsites; ++i) {
      last_value = builder.AddInstruction(HloInstruction::CreateCall(
          kScalarShape, {last_value}, callee_computation));
    }
    return builder.Build();
  }

  // Build and return a computation which takes a scalar and returns a PRED
  // value.
  std::unique_ptr<HloComputation> MakeConditionComputation() {
    HloComputation::Builder builder(TestName() + ".ConditionComputation");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    HloInstruction* zero = builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
    builder.AddInstruction(HloInstruction::CreateBinary(
        ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero));
    return builder.Build();
  }

  const Shape kScalarShape = ShapeUtil::MakeShape(F32, {});
};

TEST_F(CallGraphTest, SingletonComputation) {
  // Test the call graph of a module with a single computation.
  auto module = CreateNewModule();
  HloComputation* computation =
      module->AddEntryComputation(MakeScalarComputation());
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(1, call_graph->nodes().size());
  EXPECT_TRUE(call_graph->IsFlattened());

  const CallGraphNode& node = call_graph->GetNode(computation);
  EXPECT_EQ(computation, node.computation());
  EXPECT_TRUE(node.callsites().empty());
  EXPECT_TRUE(node.callees().empty());
  EXPECT_TRUE(node.caller_callsites().empty());
  EXPECT_TRUE(node.callers().empty());
  EXPECT_EQ(CallContext::kSequential, node.context());
}

TEST_F(CallGraphTest, UnreachableComputation) {
  // Test the call graph of a module with an entry computation and an
  // unreachable computation.
  auto module = CreateNewModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(MakeScalarComputation());
  HloComputation* unreachable_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, call_graph->nodes().size());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  EXPECT_EQ(entry_computation, entry_node.computation());
  EXPECT_EQ(CallContext::kSequential, entry_node.context());

  const CallGraphNode& unreachable_node =
      call_graph->GetNode(unreachable_computation);
  EXPECT_EQ(unreachable_computation, unreachable_node.computation());
  EXPECT_EQ(CallContext::kSequential, unreachable_node.context());
}

TEST_F(CallGraphTest, ParallelComputation) {
  // Test a call graph of a module with an entry computation which calls another
  // computation in a parallel context via kMap.
  auto module = CreateNewModule();
  HloComputation* map_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());
  HloComputation* entry_computation = module->AddEntryComputation(
      MakeMappingComputation(map_computation, /*callsites=*/5));

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, call_graph->nodes().size());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  EXPECT_EQ(entry_computation, entry_node.computation());
  EXPECT_EQ(CallContext::kSequential, entry_node.context());
  EXPECT_EQ(5, entry_node.callsites().size());
  EXPECT_EQ(1, entry_node.callees().size());
  EXPECT_TRUE(entry_node.caller_callsites().empty());
  EXPECT_TRUE(entry_node.callers().empty());

  const CallGraphNode& map_node = call_graph->GetNode(map_computation);
  EXPECT_EQ(map_computation, map_node.computation());
  EXPECT_EQ(CallContext::kParallel, map_node.context());
  EXPECT_TRUE(map_node.callsites().empty());
  EXPECT_TRUE(map_node.callees().empty());
  EXPECT_EQ(5, map_node.caller_callsites().size());
  EXPECT_EQ(1, map_node.callers().size());
}

TEST_F(CallGraphTest, SequentialComputations) {
  // Test a call graph of a module with an entry computation which calls another
  // computation in a sequential context via kCall.
  auto module = CreateNewModule();
  HloComputation* called_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());
  HloComputation* entry_computation = module->AddEntryComputation(
      MakeCallingComputation(called_computation, /*callsites=*/3));

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, call_graph->nodes().size());

  // The called computation is only called from one other computation, but there
  // are multiple callsites.
  EXPECT_FALSE(call_graph->IsFlattened());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  EXPECT_EQ(entry_computation, entry_node.computation());
  EXPECT_EQ(CallContext::kSequential, entry_node.context());
  EXPECT_EQ(3, entry_node.callsites().size());
  EXPECT_EQ(1, entry_node.callees().size());
  EXPECT_TRUE(entry_node.caller_callsites().empty());
  EXPECT_TRUE(entry_node.callers().empty());

  const CallGraphNode& called_node = call_graph->GetNode(called_computation);
  EXPECT_EQ(called_computation, called_node.computation());
  EXPECT_EQ(CallContext::kSequential, called_node.context());
  EXPECT_TRUE(called_node.callsites().empty());
  EXPECT_TRUE(called_node.callees().empty());
  EXPECT_EQ(3, called_node.caller_callsites().size());
  EXPECT_EQ(1, called_node.callers().size());
}

TEST_F(CallGraphTest, ContextBothComputations) {
  // Test a call graph of a module with an entry computation which calls another
  // computation in both a parallel and sequential context.
  auto module = CreateNewModule();
  HloComputation* subcomputation =
      module->AddEmbeddedComputation(MakeScalarComputation());

  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, kScalarShape, "param0"));
  HloInstruction* call = builder.AddInstruction(
      HloInstruction::CreateCall(kScalarShape, {param0}, subcomputation));
  HloInstruction* map = builder.AddInstruction(
      HloInstruction::CreateMap(kScalarShape, {call}, subcomputation));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, call_graph->nodes().size());

  EXPECT_FALSE(call_graph->IsFlattened());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  EXPECT_EQ(entry_computation, entry_node.computation());
  EXPECT_EQ(2, entry_node.callsites().size());

  const CallSite& call_callsite = entry_node.callsites()[0];
  EXPECT_EQ(call, call_callsite.instruction());
  EXPECT_THAT(call_callsite.called_computations(),
              UnorderedElementsAre(subcomputation));
  EXPECT_EQ(CallContext::kSequential, call_callsite.context());
  EXPECT_EQ(entry_node.GetCallSite(call), &call_callsite);

  const CallSite& map_callsite = entry_node.callsites()[1];
  EXPECT_EQ(map, map_callsite.instruction());
  EXPECT_THAT(map_callsite.called_computations(),
              UnorderedElementsAre(subcomputation));
  EXPECT_EQ(CallContext::kParallel, map_callsite.context());
  EXPECT_EQ(entry_node.GetCallSite(map), &map_callsite);

  const CallGraphNode& sub_node = call_graph->GetNode(subcomputation);
  EXPECT_EQ(CallContext::kBoth, sub_node.context());
}

TEST_F(CallGraphTest, ComputationWithConditional) {
  // Test a call graph of a module with a conditional.
  auto module = CreateNewModule();
  HloComputation* true_computation =
      module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kCeil));
  HloComputation* false_computation =
      module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kFloor));

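  // The entry computation selects between the ceil ('true') and floor
  // ('false') computations with a single kConditional instruction.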
  HloComputation::Builder builder(TestName());
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
  HloInstruction* const1 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(56.4f)));
  HloInstruction* const2 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(12.6f)));
  HloInstruction* conditional =
      builder.AddInstruction(HloInstruction::CreateConditional(
          kScalarShape, pred, const1, true_computation, const2,
          false_computation));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());

  EXPECT_EQ(3, call_graph->nodes().size());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  EXPECT_EQ(entry_computation, entry_node.computation());
  EXPECT_EQ(1, entry_node.callsites().size());

  const CallSite& conditional_callsite = entry_node.callsites()[0];
  EXPECT_EQ(conditional, conditional_callsite.instruction());
  EXPECT_THAT(conditional_callsite.called_computations(),
              UnorderedElementsAre(true_computation, false_computation));
  EXPECT_EQ(CallContext::kSequential, conditional_callsite.context());
  EXPECT_EQ(entry_node.GetCallSite(conditional), &conditional_callsite);

  const CallGraphNode& true_node = call_graph->GetNode(true_computation);
  EXPECT_TRUE(true_node.callees().empty());
  EXPECT_EQ(1, true_node.callers().size());
  EXPECT_EQ(entry_computation, true_node.callers()[0]);

  const CallGraphNode& false_node = call_graph->GetNode(false_computation);
  EXPECT_TRUE(false_node.callees().empty());
  EXPECT_EQ(1, false_node.callers().size());
  EXPECT_EQ(entry_computation, false_node.callers()[0]);
}

TEST_F(CallGraphTest, ComplexGraph) {
  // Test a call graph of a module with several computations called in various
  // contexts. The call graph looks like:
  //
  //      entry
  //      /  |
  //     a   |
  //   / | \ |
  //  b  |  cond
  //   \ |
  //    c
  //
  // Calls are made via kCall, kWhile, and kMap instructions.
  auto module = CreateNewModule();
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(MakeConditionComputation());
  HloComputation* c_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());
  HloComputation* b_computation = module->AddEmbeddedComputation(
      MakeMappingComputation(c_computation, /*callsites=*/1));

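  // Computation 'a' calls 'c' via kCall and feeds the result into a kWhile
  // with condition 'cond' and body 'b'.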
  HloComputation* a_computation;
  {
    HloComputation::Builder builder(TestName() + ".a");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    HloInstruction* call = builder.AddInstruction(
        HloInstruction::CreateCall(kScalarShape, {param0}, c_computation));
    builder.AddInstruction(HloInstruction::CreateWhile(
        kScalarShape, cond_computation, b_computation, call));
    a_computation = module->AddEmbeddedComputation(builder.Build());
  }

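  // The entry computation runs a kWhile with condition 'cond' and body 'a'.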
  HloComputation* entry_computation;
  {
    HloComputation::Builder builder(TestName() + ".entry");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    builder.AddInstruction(HloInstruction::CreateWhile(
        kScalarShape, cond_computation, a_computation, param0));
    entry_computation = module->AddEntryComputation(builder.Build());
  }

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(5, call_graph->nodes().size());
  EXPECT_FALSE(call_graph->IsFlattened());

  // Entry computation has one while instruction calling two computations
  // (cond_computation and a_computation).
  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  ASSERT_EQ(1, entry_node.callsites().size());
  const std::vector<HloComputation*>& called_computations =
      entry_node.callsites()[0].called_computations();
  EXPECT_THAT(called_computations,
              UnorderedElementsAre(cond_computation, a_computation));
  EXPECT_EQ(CallContext::kSequential, entry_node.context());

  const CallGraphNode& c_node = call_graph->GetNode(c_computation);
  EXPECT_TRUE(c_node.callsites().empty());
  EXPECT_THAT(c_node.callers(),
              UnorderedElementsAre(a_computation, b_computation));
  EXPECT_EQ(CallContext::kBoth, c_node.context());

  // Visit the graph and verify nodes were visited in callee-before-caller
  // order.
  std::vector<const HloComputation*> visited;
  TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) {
    visited.push_back(node.computation());
    return Status::OK();
  }));
  EXPECT_EQ(visited.size(), 5);
  // All values in visited should be unique.
  EXPECT_EQ(
      std::unordered_set<const HloComputation*>(visited.begin(), visited.end())
          .size(),
      5);

  // Verify visitation order of some computations in the graph.
  auto index_of = [&visited](const HloComputation* comp) {
    auto it = std::find(visited.begin(), visited.end(), comp);
    EXPECT_NE(it, visited.end());
    return std::distance(visited.begin(), it);
  };
  EXPECT_EQ(4, index_of(entry_computation));
  EXPECT_LT(index_of(cond_computation), index_of(a_computation));
  EXPECT_LT(index_of(c_computation), index_of(b_computation));
  EXPECT_LT(index_of(b_computation), index_of(a_computation));

  // Verify dominance relations between computations in the graph.
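  // (A computation 'x' dominates a computation 'y' if every path in the call
  // graph from the entry computation to 'y' passes through 'x'.)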

  // Entry dominates everybody, and is dominated by no one except itself.
  EXPECT_TRUE(call_graph->Dominates(entry_computation, entry_computation));
  EXPECT_TRUE(call_graph->Dominates(entry_computation, a_computation));
  EXPECT_TRUE(call_graph->Dominates(entry_computation, b_computation));
  EXPECT_TRUE(call_graph->Dominates(entry_computation, c_computation));
  EXPECT_TRUE(call_graph->Dominates(entry_computation, cond_computation));
  EXPECT_FALSE(call_graph->Dominates(a_computation, entry_computation));
  EXPECT_FALSE(call_graph->Dominates(b_computation, entry_computation));
  EXPECT_FALSE(call_graph->Dominates(c_computation, entry_computation));
  EXPECT_FALSE(call_graph->Dominates(cond_computation, entry_computation));

  // 'a' dominates only itself, 'b', and 'c'.
  EXPECT_TRUE(call_graph->Dominates(a_computation, a_computation));
  EXPECT_TRUE(call_graph->Dominates(a_computation, b_computation));
  EXPECT_TRUE(call_graph->Dominates(a_computation, c_computation));
  EXPECT_FALSE(call_graph->Dominates(b_computation, a_computation));
  EXPECT_FALSE(call_graph->Dominates(c_computation, a_computation));
  EXPECT_FALSE(call_graph->Dominates(a_computation, cond_computation));

  EXPECT_TRUE(call_graph->Dominates(b_computation, b_computation));
  EXPECT_FALSE(call_graph->Dominates(b_computation, c_computation));
  EXPECT_FALSE(call_graph->Dominates(b_computation, cond_computation));

  EXPECT_TRUE(call_graph->Dominates(c_computation, c_computation));
  EXPECT_FALSE(call_graph->Dominates(c_computation, cond_computation));
  EXPECT_FALSE(call_graph->Dominates(cond_computation, c_computation));

  EXPECT_TRUE(call_graph->Dominates(cond_computation, cond_computation));
}

TEST_F(CallGraphTest, ComplexGraphNearestAncestors) {
  // Test NearestAncestorsInSameComputation on a call graph of a module with
  // several computations called in various contexts. The call graph looks like:
  //
  //      entry
  //      /  |
  //     a   |
  //   / | \ |
  //  b  |  cond
  //   \ |
  //    c
  //
  // Calls are made via kCall, kWhile, and kMap instructions.
  auto module = CreateNewModule();
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(MakeConditionComputation());
  HloComputation* c_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());
  HloComputation* b_computation = module->AddEmbeddedComputation(
      MakeMappingComputation(c_computation, /*callsites=*/1));
  HloInstruction* b_map = b_computation->root_instruction();

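  // As in ComplexGraph, computation 'a' calls 'c' via kCall and feeds the
  // result into a kWhile with condition 'cond' and body 'b'. The call and
  // while instructions are kept for the ancestor queries below.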
  HloComputation* a_computation;
  HloInstruction* a_call;
  HloInstruction* a_while;
  {
    HloComputation::Builder builder(TestName() + ".a");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    a_call = builder.AddInstruction(
        HloInstruction::CreateCall(kScalarShape, {param0}, c_computation));
    a_while = builder.AddInstruction(HloInstruction::CreateWhile(
        kScalarShape, cond_computation, b_computation, a_call));
    a_computation = module->AddEmbeddedComputation(builder.Build());
  }

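  // The entry computation runs a kWhile with condition 'cond' and body 'a';
  // its while instruction is also kept for the ancestor queries below.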
  HloComputation* entry_computation;
  HloInstruction* entry_while;
  {
    HloComputation::Builder builder(TestName() + ".entry");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    entry_while = builder.AddInstruction(HloInstruction::CreateWhile(
        kScalarShape, cond_computation, a_computation, param0));
    entry_computation = module->AddEntryComputation(builder.Build());
  }

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(5, call_graph->nodes().size());

  // Verify NearestAncestorsInSameComputation for various instructions in the
  // module.
  EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(a_call, a_call),
            std::make_pair(a_call, a_call));

  // c_computation is called from more than one site, so
  // NearestAncestorsInSameComputation bails and returns nullptrs.
  std::pair<HloInstruction*, HloInstruction*> null_pair = {nullptr, nullptr};
  EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(
                b_map, c_computation->root_instruction()),
            null_pair);

  EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(b_map, entry_while),
            std::make_pair(entry_while, entry_while));
  EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(b_map, a_call),
            std::make_pair(a_while, a_call));
  EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(a_while, a_call),
            std::make_pair(a_while, a_call));
  EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(a_while, b_map),
            std::make_pair(a_while, a_while));
}

TEST_F(CallGraphTest, VisitSingletonComputation) {
  // Test the call graph visitor on a call graph with a single node.
  auto module = CreateNewModule();
  HloComputation* computation =
      module->AddEntryComputation(MakeScalarComputation());
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());

  std::vector<HloComputation*> visited;
  TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) {
    visited.push_back(node.computation());
    return Status::OK();
  }));
  EXPECT_THAT(visited, UnorderedElementsAre(computation));
}

TEST_F(CallGraphTest, VisitUnreachableComputation) {
  // Test the call graph visitor on a call graph with an unreachable node.
  auto module = CreateNewModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(MakeScalarComputation());
  HloComputation* unreachable_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());

  // Test visitation of only reachable nodes.
  {
    std::vector<const HloComputation*> visited;
    TF_ASSERT_OK(call_graph->VisitNodes(
        [&visited](const CallGraphNode& node) {
          visited.push_back(node.computation());
          return Status::OK();
        },
        /*visit_unreachable_nodes=*/false));
    EXPECT_EQ(visited.size(), 1);
    EXPECT_EQ(visited[0], entry_computation);
  }

  // Test visitation of all nodes (reachable and unreachable).
  {
    std::vector<HloComputation*> visited;
    TF_ASSERT_OK(call_graph->VisitNodes(
        [&visited](const CallGraphNode& node) {
          visited.push_back(node.computation());
          return Status::OK();
        },
        /*visit_unreachable_nodes=*/true));
    EXPECT_EQ(visited.size(), 2);
    EXPECT_THAT(visited, UnorderedElementsAre(entry_computation,
                                              unreachable_computation));
  }
}

TEST_F(CallGraphTest, VisitWithError) {
  // Test that the call graph visitor properly propagates errors.
  auto module = CreateNewModule();
  module->AddEntryComputation(MakeScalarComputation());
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());

  Status status = call_graph->VisitNodes(
      [](const CallGraphNode&) { return InternalError("Visitation failed"); });

  ASSERT_FALSE(status.ok());
  ASSERT_EQ(status.code(), tensorflow::error::INTERNAL);
  ASSERT_THAT(status.error_message(),
              ::testing::HasSubstr("Visitation failed"));
}

}  // namespace
}  // namespace xla