Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     17 
     18 #include <set>
     19 
     20 #include "tensorflow/compiler/xla/literal_util.h"
     21 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
     22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     23 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
     24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     25 #include "tensorflow/compiler/xla/shape_util.h"
     26 #include "tensorflow/compiler/xla/test.h"
     27 #include "tensorflow/compiler/xla/test_helpers.h"
     28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     29 
     30 namespace op = xla::testing::opcode_matchers;
     31 
     32 namespace xla {
     33 
     34 namespace {
     35 
     36 using ::testing::ElementsAre;
     37 using ::testing::UnorderedElementsAre;
     38 
     39 class HloComputationTest : public HloTestBase {
     40  protected:
     41   HloComputationTest() {}
     42 
     43   // Create a computation which takes a scalar and returns its negation.
     44   std::unique_ptr<HloComputation> CreateNegateComputation() {
     45     auto builder = HloComputation::Builder("Negate");
     46     auto param = builder.AddInstruction(
     47         HloInstruction::CreateParameter(0, r0f32_, "param0"));
     48     builder.AddInstruction(
     49         HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param));
     50     return builder.Build();
     51   }
     52 
     53   // Creates a computation which calls map with the given computation.
     54   std::unique_ptr<HloComputation> CreateMapComputation(
     55       HloComputation* map_computation) {
     56     auto builder = HloComputation::Builder("Map");
     57     auto param = builder.AddInstruction(
     58         HloInstruction::CreateParameter(0, r0f32_, "param0"));
     59     builder.AddInstruction(
     60         HloInstruction::CreateMap(r0f32_, {param}, map_computation));
     61     return builder.Build();
     62   }
     63 
     64   Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
     65 };
     66 
     67 TEST_F(HloComputationTest, GetEmbeddedComputationsEmpty) {
     68   auto module = CreateNewModule();
     69   auto negate_computation =
     70       module->AddEntryComputation(CreateNegateComputation());
     71   EXPECT_TRUE(negate_computation->MakeEmbeddedComputationsList().empty());
     72 }
     73 
     74 TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) {
     75   // Create computation which calls one other computation.
     76   auto module = CreateNewModule();
     77   auto negate_computation =
     78       module->AddEmbeddedComputation(CreateNegateComputation());
     79   auto map_computation =
     80       module->AddEntryComputation(CreateMapComputation(negate_computation));
     81   EXPECT_TRUE(negate_computation->MakeEmbeddedComputationsList().empty());
     82   EXPECT_THAT(map_computation->MakeEmbeddedComputationsList(),
     83               ElementsAre(negate_computation));
     84 }
     85 
     86 TEST_F(HloComputationTest, GetEmbeddedComputationsDiamond) {
     87   // Create computations with a diamond-shaped callgraph.
     88   auto module = CreateNewModule();
     89   auto negate_computation =
     90       module->AddEmbeddedComputation(CreateNegateComputation());
     91   auto map1_computation =
     92       module->AddEmbeddedComputation(CreateMapComputation(negate_computation));
     93   auto map2_computation =
     94       module->AddEmbeddedComputation(CreateMapComputation(negate_computation));
     95 
     96   auto builder = HloComputation::Builder(TestName());
     97   auto param = builder.AddInstruction(
     98       HloInstruction::CreateParameter(0, r0f32_, "param0"));
     99   auto map1 = builder.AddInstruction(
    100       HloInstruction::CreateMap(r0f32_, {param}, map1_computation));
    101   auto map2 = builder.AddInstruction(
    102       HloInstruction::CreateMap(r0f32_, {param}, map2_computation));
    103   builder.AddInstruction(
    104       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2));
    105   auto computation = module->AddEntryComputation(builder.Build());
    106 
    107   auto embedded_computations = computation->MakeEmbeddedComputationsList();
    108   EXPECT_EQ(3, embedded_computations.size());
    109   // GetEmbeddedComputations returns a post order of the embedded computations,
    110   // so the negate computation must come first.
    111   EXPECT_EQ(negate_computation, *embedded_computations.begin());
    112   EXPECT_THAT(embedded_computations,
    113               UnorderedElementsAre(negate_computation, map1_computation,
    114                                    map2_computation));
    115 }
    116 
    117 TEST_F(HloComputationTest, PostOrderSingleton) {
    118   // Test GetInstructionPostOrder for a computation with one instruction.
    119   auto builder = HloComputation::Builder(TestName());
    120   auto constant = builder.AddInstruction(
    121       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
    122   auto module = CreateNewModule();
    123   auto computation = module->AddEntryComputation(builder.Build());
    124   EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant));
    125 }
    126 
    127 TEST_F(HloComputationTest, PostOrderSimple) {
    128   // Test GetInstructionPostOrder for a computation with a chain of
    129   // instructions.
    130   auto builder = HloComputation::Builder(TestName());
    131   auto constant = builder.AddInstruction(
    132       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
    133   auto negate1 = builder.AddInstruction(
    134       HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
    135   auto negate2 = builder.AddInstruction(
    136       HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1));
    137   auto module = CreateNewModule();
    138   auto computation = module->AddEntryComputation(builder.Build());
    139   EXPECT_THAT(computation->MakeInstructionPostOrder(),
    140               ElementsAre(constant, negate1, negate2));
    141 }
    142 
    143 TEST_F(HloComputationTest, PostOrderTrace) {
    144   // Test GetInstructionPostOrder for a computation with a trace instruction.
    145   auto builder = HloComputation::Builder(TestName());
    146   auto constant = builder.AddInstruction(
    147       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
    148   auto negate1 = builder.AddInstruction(
    149       HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
    150   auto trace =
    151       builder.AddInstruction(HloInstruction::CreateTrace("foobar", negate1));
    152   auto negate2 = builder.AddInstruction(
    153       HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1));
    154   auto module = CreateNewModule();
    155   auto computation = module->AddEntryComputation(builder.Build());
    156   // Trace instructions should be at the end of the sort.
    157   EXPECT_THAT(computation->MakeInstructionPostOrder(),
    158               ElementsAre(constant, negate1, negate2, trace));
    159 }
    160 
    161 TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) {
    162   // Test GetInstructionPostOrder for a computation with multiple instructions
    163   // which are not connected.
    164   auto builder = HloComputation::Builder(TestName());
    165   auto constant1 = builder.AddInstruction(
    166       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
    167   auto constant2 = builder.AddInstruction(
    168       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
    169   auto constant3 = builder.AddInstruction(
    170       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
    171   auto constant4 = builder.AddInstruction(
    172       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
    173   auto module = CreateNewModule();
    174   auto computation = module->AddEntryComputation(builder.Build());
    175   EXPECT_THAT(computation->MakeInstructionPostOrder(),
    176               UnorderedElementsAre(constant1, constant2, constant3, constant4));
    177 }
    178 
    179 TEST_F(HloComputationTest, PostOrderWithMultipleRoots) {
    180   // Test GetInstructionPostOrder for a computation with multiple instructions
    181   // which are not connected.
    182   auto builder = HloComputation::Builder(TestName());
    183   auto constant1 = builder.AddInstruction(
    184       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
    185   auto constant2 = builder.AddInstruction(
    186       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
    187   auto constant3 = builder.AddInstruction(
    188       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
    189   auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
    190       r0f32_, HloOpcode::kAdd, constant1, constant2));
    191   auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
    192       r0f32_, HloOpcode::kAdd, constant2, constant3));
    193   auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
    194       r0f32_, HloOpcode::kAdd, constant1, constant3));
    195   auto module = CreateNewModule();
    196   auto computation = module->AddEntryComputation(builder.Build());
    197   auto post_order = computation->MakeInstructionPostOrder();
    198   EXPECT_EQ(6, post_order.size());
    199   EXPECT_THAT(post_order, UnorderedElementsAre(constant1, constant2, constant3,
    200                                                add1, add2, add3));
    201 }
    202 
    203 TEST_F(HloComputationTest, VisitWithMultipleRoots) {
    204   // Test that Accept visits all instructions in the computation even if the
    205   // computation has multiple roots (dead code).
    206   auto builder = HloComputation::Builder(TestName());
    207   auto constant1 = builder.AddInstruction(
    208       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
    209   auto constant2 = builder.AddInstruction(
    210       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
    211   auto constant3 = builder.AddInstruction(
    212       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
    213   // Add three disconnected add expressions.
    214   builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
    215                                                       constant1, constant2));
    216   builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
    217                                                       constant2, constant3));
    218   builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
    219                                                       constant1, constant3));
    220   auto module = CreateNewModule();
    221   auto computation = module->AddEntryComputation(builder.Build());
    222   // Visitor which keeps track of which instructions have been visited.
    223   class TestVisitor : public DfsHloVisitorWithDefault {
    224    public:
    225     explicit TestVisitor(HloComputation* computation)
    226         : computation_(computation) {}
    227 
    228     Status DefaultAction(HloInstruction* hlo_instruction) override {
    229       EXPECT_EQ(0, visited_set_.count(hlo_instruction));
    230       visited_set_.insert(hlo_instruction);
    231       last_visited_ = hlo_instruction;
    232       return Status::OK();
    233     }
    234 
    235     Status FinishVisit(HloInstruction* root) override {
    236       EXPECT_EQ(computation_->root_instruction(), root);
    237       ++finish_visit_calls_;
    238       return Status::OK();
    239     }
    240 
    241     HloComputation* computation_;
    242     std::set<HloInstruction*> visited_set_;
    243     int64 finish_visit_calls_ = 0;
    244     HloInstruction* last_visited_ = nullptr;
    245   };
    246 
    247   TestVisitor visitor(computation);
    248   EXPECT_IS_OK(computation->Accept(&visitor));
    249 
    250   EXPECT_EQ(6, visitor.visited_set_.size());
    251   EXPECT_EQ(1, visitor.finish_visit_calls_);
    252   EXPECT_EQ(computation->root_instruction(), visitor.last_visited_);
    253 }
    254 
    255 TEST_F(HloComputationTest, DeepCopyArray) {
    256   // Test that DeepCopyInstruction properly copies an array.
    257   auto builder = HloComputation::Builder(TestName());
    258   auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
    259       Literal::CreateR1<float>({1.0, 2.0, 3.0})));
    260   auto module = CreateNewModule();
    261   auto computation = module->AddEntryComputation(builder.Build());
    262   auto copy = computation->DeepCopyInstruction(constant).ValueOrDie();
    263 
    264   EXPECT_THAT(copy, op::Copy(constant));
    265 }
    266 
    267 TEST_F(HloComputationTest, DeepCopyTuple) {
    268   // Test that DeepCopyInstruction properly copies a tuple.
    269   auto builder = HloComputation::Builder(TestName());
    270   auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
    271       Literal::CreateR1<float>({1.0, 2.0, 3.0})));
    272   auto constant2 = builder.AddInstruction(
    273       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
    274   auto tuple = builder.AddInstruction(
    275       HloInstruction::CreateTuple({constant1, constant2}));
    276 
    277   auto module = CreateNewModule();
    278   auto computation = module->AddEntryComputation(builder.Build());
    279   auto tuple_copy = computation->DeepCopyInstruction(tuple).ValueOrDie();
    280 
    281   EXPECT_THAT(tuple_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)),
    282                                     op::Copy(op::GetTupleElement(tuple))));
    283   EXPECT_EQ(0, tuple_copy->operand(0)->operand(0)->tuple_index());
    284   EXPECT_EQ(1, tuple_copy->operand(1)->operand(0)->tuple_index());
    285 }
    286 
    287 TEST_F(HloComputationTest, DeepCopyArrayAtIndices) {
    288   // Test that DeepCopyInstruction properly handles an array when the indices to
    289   // copy are specified.
    290   auto builder = HloComputation::Builder(TestName());
    291   auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
    292       Literal::CreateR1<float>({1.0, 2.0, 3.0})));
    293   auto computation = builder.Build();
    294 
    295   {
    296     // If the index is true, then a copy should be made.
    297     ShapeTree<bool> indices_to_copy(constant->shape(), /*init_value=*/true);
    298     EXPECT_THAT(computation->DeepCopyInstruction(constant, &indices_to_copy)
    299                     .ValueOrDie(),
    300                 op::Copy(constant));
    301   }
    302 
    303   {
    304     // If the index is false, then no copy should be made.
    305     ShapeTree<bool> indices_to_copy(constant->shape(), /*init_value=*/false);
    306     EXPECT_EQ(computation->DeepCopyInstruction(constant, &indices_to_copy)
    307                   .ValueOrDie(),
    308               constant);
    309   }
    310 }
    311 
    312 TEST_F(HloComputationTest, DeepCopyTupleAtIndices) {
    313   // Test that DeepCopyInstruction properly copies elements of a tuple as
    314   // specified by the given indices.
    315   auto builder = HloComputation::Builder(TestName());
    316   auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
    317       Literal::CreateR1<float>({1.0, 2.0, 3.0})));
    318   auto constant2 = builder.AddInstruction(
    319       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
    320   auto tuple = builder.AddInstruction(
    321       HloInstruction::CreateTuple({constant1, constant2}));
    322   auto computation = builder.Build();
    323 
    324   {
    325     // All true values should copy all array elements.
    326     ShapeTree<bool> indices_to_copy(tuple->shape(), /*init_value=*/true);
    327     ShapeTree<HloInstruction*> copies_added(tuple->shape(),
    328                                             /*init_value=*/nullptr);
    329     HloInstruction* deep_copy =
    330         computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added)
    331             .ValueOrDie();
    332 
    333     EXPECT_THAT(deep_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)),
    334                                      op::Copy(op::GetTupleElement(tuple))));
    335     EXPECT_THAT(deep_copy, op::Tuple(copies_added.element({0}),
    336                                      copies_added.element({1})));
    337   }
    338 
    339   {
    340     // All false elements should copy no array elements, but the GTE and tuple
    341     // instruction scaffolding should be built.
    342     ShapeTree<bool> indices_to_copy(tuple->shape(), /*init_value=*/false);
    343     ShapeTree<HloInstruction*> copies_added(tuple->shape(),
    344                                             /*init_value=*/nullptr);
    345     HloInstruction* deep_copy =
    346         computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added)
    347             .ValueOrDie();
    348 
    349     EXPECT_THAT(deep_copy, op::Tuple(op::GetTupleElement(tuple),
    350                                      op::GetTupleElement(tuple)));
    351     EXPECT_TRUE(copies_added.element({}) == nullptr);
    352     EXPECT_TRUE(copies_added.element({0}) == nullptr);
    353     EXPECT_TRUE(copies_added.element({1}) == nullptr);
    354   }
    355 
    356   {
    357     // Verify one element copied, the other not.
    358     ShapeTree<bool> indices_to_copy(tuple->shape(), /*init_value=*/false);
    359     *indices_to_copy.mutable_element({0}) = true;
    360     ShapeTree<HloInstruction*> copies_added(tuple->shape(),
    361                                             /*init_value=*/nullptr);
    362     HloInstruction* deep_copy =
    363         computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added)
    364             .ValueOrDie();
    365 
    366     EXPECT_THAT(deep_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)),
    367                                      op::GetTupleElement(tuple)));
    368     EXPECT_TRUE(copies_added.element({}) == nullptr);
    369     EXPECT_TRUE(copies_added.element({0}) != nullptr);
    370     EXPECT_TRUE(copies_added.element({1}) == nullptr);
    371   }
    372 }
    373 
    374 TEST_F(HloComputationTest, CycleDetection) {
    375   // Test whether the visitor can detect cycles in the graph.
    376   auto builder = HloComputation::Builder(TestName());
    377   auto constant = builder.AddInstruction(
    378       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
    379   auto negate = builder.AddInstruction(
    380       HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
    381   auto add = builder.AddInstruction(
    382       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, negate, negate));
    383   auto module = CreateNewModule();
    384   auto computation = module->AddEntryComputation(builder.Build());
    385   // Add a control dependency to create a cycle.
    386   ASSERT_IS_OK(add->AddControlDependencyTo(negate));
    387 
    388   const auto visitor = [](HloInstruction* instruction) { return Status::OK(); };
    389   auto visit_status = computation->Accept(visitor);
    390   ASSERT_FALSE(visit_status.ok());
    391   ASSERT_THAT(visit_status.error_message(),
    392               ::testing::ContainsRegex("cycle is detecte"));
    393 }
    394 
    395 TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) {
    396   // Test RemoveInstructionAndUnusedOperands with an instruction which has a
    397   // duplicated (dead) operand. This verifies that the operand is not deleted
    398   // twice.
    399   auto builder = HloComputation::Builder(TestName());
    400   auto constant = builder.AddInstruction(
    401       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
    402   auto dead_negate = builder.AddInstruction(
    403       HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
    404   auto dead_add = builder.AddInstruction(HloInstruction::CreateBinary(
    405       r0f32_, HloOpcode::kAdd, dead_negate, dead_negate));
    406   auto negate = builder.AddInstruction(
    407       HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
    408   auto module = CreateNewModule();
    409   auto computation = module->AddEntryComputation(builder.Build());
    410   EXPECT_EQ(4, computation->instruction_count());
    411   EXPECT_THAT(computation->root_instruction(), op::Negate(constant));
    412   EXPECT_EQ(negate, computation->root_instruction());
    413 
    414   ASSERT_IS_OK(computation->RemoveInstructionAndUnusedOperands(dead_add));
    415 
    416   EXPECT_EQ(2, computation->instruction_count());
    417   EXPECT_THAT(computation->root_instruction(), op::Negate(constant));
    418   EXPECT_EQ(negate, computation->root_instruction());
    419 }
    420 
    421 TEST_F(HloComputationTest, CloneWithControlDependency) {
    422   auto builder = HloComputation::Builder(TestName());
    423   auto constant1 = builder.AddInstruction(
    424       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
    425   auto constant2 = builder.AddInstruction(
    426       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0f)));
    427   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
    428       r0f32_, HloOpcode::kAdd, constant1, constant2));
    429 
    430   auto param = builder.AddInstruction(
    431       HloInstruction::CreateParameter(0, r0f32_, "param0"));
    432   auto negate = builder.AddInstruction(
    433       HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param));
    434   auto module = CreateNewModule();
    435   auto computation =
    436       module->AddEntryComputation(builder.Build(/*root_instruction=*/add));
    437 
    438   TF_CHECK_OK(negate->AddControlDependencyTo(add));
    439 
    440   auto clone = computation->Clone();
    441 
    442   auto cloned_add = clone->root_instruction();
    443   EXPECT_EQ(cloned_add->opcode(), HloOpcode::kAdd);
    444 
    445   auto predecessors = cloned_add->control_predecessors();
    446   EXPECT_EQ(1, predecessors.size());
    447   EXPECT_EQ(HloOpcode::kNegate, predecessors[0]->opcode());
    448   auto successors = predecessors[0]->control_successors();
    449   EXPECT_THAT(successors, ::testing::ElementsAre(cloned_add));
    450 }
    451 
    452 TEST_F(HloComputationTest, Reachability) {
    453   // Test reachability of a non-trivial computation:
    454   //
    455   // const1    const2
    456   //    |         |
    457   //    | +-------+
    458   //    | |       |
    459   //    add ..   negate
    460   //     |   .     |
    461   //     |   .... exp
    462   //     |         |
    463   //     +---+   +-+---+
    464   //         |   |     |
    465   //       multiply   copy
    466   //
    467   // There is a control dependency from 'add' to 'exp'.
    468   auto builder = HloComputation::Builder(TestName());
    469   auto constant1 = builder.AddInstruction(
    470       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
    471   auto constant2 = builder.AddInstruction(
    472       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0f)));
    473   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
    474       r0f32_, HloOpcode::kAdd, constant1, constant2));
    475   auto negate = builder.AddInstruction(
    476       HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant2));
    477   auto exp = builder.AddInstruction(
    478       HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, negate));
    479   auto mul = builder.AddInstruction(
    480       HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, add, exp));
    481   auto copy = builder.AddInstruction(
    482       HloInstruction::CreateUnary(r0f32_, HloOpcode::kCopy, exp));
    483 
    484   auto module = CreateNewModule();
    485   auto computation =
    486       module->AddEntryComputation(builder.Build(/*root_instruction=*/mul));
    487 
    488   TF_CHECK_OK(add->AddControlDependencyTo(exp));
    489   auto reachability = computation->ComputeReachability();
    490 
    491   EXPECT_TRUE(reachability->IsReachable(constant1, constant1));
    492   EXPECT_FALSE(reachability->IsReachable(constant1, constant2));
    493   EXPECT_TRUE(reachability->IsReachable(constant1, add));
    494   EXPECT_FALSE(reachability->IsReachable(constant1, negate));
    495   EXPECT_TRUE(reachability->IsReachable(constant1, exp));
    496   EXPECT_TRUE(reachability->IsReachable(constant1, mul));
    497   EXPECT_TRUE(reachability->IsReachable(constant1, copy));
    498 
    499   EXPECT_FALSE(reachability->IsReachable(constant2, constant1));
    500   EXPECT_TRUE(reachability->IsReachable(constant2, constant2));
    501   EXPECT_TRUE(reachability->IsReachable(constant2, add));
    502   EXPECT_TRUE(reachability->IsReachable(constant2, negate));
    503   EXPECT_TRUE(reachability->IsReachable(constant2, exp));
    504   EXPECT_TRUE(reachability->IsReachable(constant2, mul));
    505   EXPECT_TRUE(reachability->IsReachable(constant2, copy));
    506 
    507   EXPECT_FALSE(reachability->IsReachable(exp, constant1));
    508   EXPECT_FALSE(reachability->IsReachable(exp, constant2));
    509   EXPECT_FALSE(reachability->IsReachable(exp, add));
    510   EXPECT_FALSE(reachability->IsReachable(exp, negate));
    511   EXPECT_TRUE(reachability->IsReachable(exp, exp));
    512   EXPECT_TRUE(reachability->IsReachable(exp, mul));
    513   EXPECT_TRUE(reachability->IsReachable(exp, copy));
    514 
    515   EXPECT_FALSE(reachability->IsReachable(mul, constant1));
    516   EXPECT_FALSE(reachability->IsReachable(mul, constant2));
    517   EXPECT_FALSE(reachability->IsReachable(mul, add));
    518   EXPECT_FALSE(reachability->IsReachable(mul, negate));
    519   EXPECT_FALSE(reachability->IsReachable(mul, exp));
    520   EXPECT_TRUE(reachability->IsReachable(mul, mul));
    521   EXPECT_FALSE(reachability->IsReachable(mul, copy));
    522 
    523   EXPECT_TRUE(reachability->IsConnected(constant1, copy));
    524   EXPECT_TRUE(reachability->IsConnected(copy, constant1));
    525   EXPECT_FALSE(reachability->IsConnected(negate, add));
    526   EXPECT_FALSE(reachability->IsConnected(add, negate));
    527 
    528   // Remove the control dependency then update and verify the reachability map
    529   ASSERT_IS_OK(add->RemoveControlDependencyTo(exp));
    530   computation->UpdateReachabilityThroughInstruction(exp, reachability.get());
    531 
    532   EXPECT_TRUE(reachability->IsReachable(constant1, constant1));
    533   EXPECT_FALSE(reachability->IsReachable(constant1, constant2));
    534   EXPECT_TRUE(reachability->IsReachable(constant1, add));
    535   EXPECT_FALSE(reachability->IsReachable(constant1, negate));
    536   EXPECT_FALSE(reachability->IsReachable(constant1, exp));
    537   EXPECT_TRUE(reachability->IsReachable(constant1, mul));
    538   EXPECT_FALSE(reachability->IsReachable(constant1, copy));
    539 
    540   // Change a use within the graph then update and verify the reachability map
    541   ASSERT_IS_OK(constant2->ReplaceUseWith(negate, constant1));
    542   computation->UpdateReachabilityThroughInstruction(negate, reachability.get());
    543 
    544   EXPECT_FALSE(reachability->IsReachable(constant2, constant1));
    545   EXPECT_TRUE(reachability->IsReachable(constant2, constant2));
    546   EXPECT_TRUE(reachability->IsReachable(constant2, add));
    547   EXPECT_FALSE(reachability->IsReachable(constant2, negate));
    548   EXPECT_FALSE(reachability->IsReachable(constant2, exp));
    549   EXPECT_TRUE(reachability->IsReachable(constant2, mul));
    550   EXPECT_FALSE(reachability->IsReachable(constant2, copy));
    551 }
    552 
    553 }  // namespace
    554 
    555 }  // namespace xla
    556