Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
     17 
     18 #include <memory>
     19 #include <string>
     20 
     21 #include "tensorflow/compiler/xla/ptr_util.h"
     22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     25 #include "tensorflow/compiler/xla/shape_util.h"
     26 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     27 #include "tensorflow/compiler/xla/types.h"
     28 #include "tensorflow/compiler/xla/xla_data.pb.h"
     29 
     30 namespace xla {
     31 namespace {
     32 
     33 class BufferLivenessTest : public HloTestBase {
     34  protected:
     35   // Returns the LogicalBuffer defined at the given instruction and
     36   // index. CHECKs if no buffer is defined at that point.
     37   const LogicalBuffer& GetBuffer(const BufferLiveness& liveness,
     38                                  const HloInstruction* instruction,
     39                                  const ShapeIndex& index) {
     40     const auto& pointed_to = liveness.points_to_analysis()
     41                                  .GetPointsToSet(instruction)
     42                                  .element(index);
     43     CHECK_EQ(1, pointed_to.size());
     44     CHECK_EQ(instruction, pointed_to[0]->instruction());
     45     CHECK(index == pointed_to[0]->index());
     46     return *pointed_to[0];
     47   }
     48 
     49   // Returns true if the top-level buffers for instructions 'a' and 'b' may
     50   // interfere. Precondition: 'a' and 'b' are array-shaped.
     51   bool InstructionsMayInterfere(const BufferLiveness& liveness,
     52                                 HloInstruction* a, HloInstruction* b) {
     53     EXPECT_FALSE(ShapeUtil::IsTuple(a->shape()));
     54     EXPECT_FALSE(ShapeUtil::IsTuple(b->shape()));
     55     return liveness.MayInterfere(
     56         GetBuffer(liveness, /*instruction=*/a, /*index=*/{}),
     57         GetBuffer(liveness, /*instruction=*/b, /*index=*/{}));
     58   }
     59 
     60   // Returns true if the tuple elements at 'index' for instructions 'a' and 'b'
     61   // may interfere. Precondition: 'a' and 'b' are tuple-shaped, with equal
     62   // tuple element sub-shapes.
     63   bool TupleElementsMayInterfere(const BufferLiveness& liveness,
     64                                  HloInstruction* a, HloInstruction* b,
     65                                  const ShapeIndex& index) {
     66     // Check that top-level shapes are tuple and tuple element shapes are equal.
     67     EXPECT_TRUE(ShapeUtil::IsTuple(a->shape()));
     68     EXPECT_TRUE(ShapeUtil::IsTuple(b->shape()));
     69     EXPECT_TRUE(
     70         ShapeUtil::Compatible(ShapeUtil::GetSubshape(a->shape(), index),
     71                               ShapeUtil::GetSubshape(b->shape(), index)));
     72     // Lookup PointsTo set for instructions 'a' and 'b'.
     73     auto& points_to_analysis = liveness.points_to_analysis();
     74     const auto& points_to_a =
     75         points_to_analysis.GetPointsToSet(a).element(index);
     76     const auto& points_to_b =
     77         points_to_analysis.GetPointsToSet(b).element(index);
     78     // Make sure PointsTo sets for 'a' and 'b' are unambiguous.
     79     EXPECT_EQ(1, points_to_a.size());
     80     EXPECT_EQ(points_to_a.size(), points_to_b.size());
     81     // Check interference.
     82     return liveness.MayInterfere(*points_to_a[0], *points_to_b[0]);
     83   }
     84 
     85   // Returns true if the top-level buffers for the given instruction maybe
     86   // liveout of the entry computation.
     87   // Precondition: instruction is array-shaped.
     88   bool InstructionMaybeLiveOut(const BufferLiveness& liveness,
     89                                HloInstruction* instruction) {
     90     return liveness.MaybeLiveOut(
     91         GetBuffer(liveness, instruction, /*index=*/{}));
     92   }
     93 
     94   std::unique_ptr<HloComputation> BuildDummyComputation() {
     95     auto builder = HloComputation::Builder(TestName() + "_dummy");
     96     builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
     97     return builder.Build();
     98   }
     99 
    100   const Shape vec_ = ShapeUtil::MakeShape(xla::F32, {42});
    101 };
    102 
    103 TEST_F(BufferLivenessTest, ElementwiseChain) {
    104   // A simple chain of elementwise operations. No buffers should interfere.
    105   //
    106   // param --> negate -> exp -> log
    107   //
    108   auto builder = HloComputation::Builder(TestName());
    109   auto param =
    110       builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
    111   auto negate = builder.AddInstruction(
    112       HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
    113   auto exp = builder.AddInstruction(
    114       HloInstruction::CreateUnary(vec_, HloOpcode::kExp, negate));
    115   auto log = builder.AddInstruction(
    116       HloInstruction::CreateUnary(vec_, HloOpcode::kLog, exp));
    117 
    118   auto module = CreateNewModule();
    119   module->AddEntryComputation(builder.Build());
    120 
    121   auto liveness =
    122       BufferLiveness::Run(module.get(),
    123                           xla::MakeUnique<DependencyHloOrdering>(module.get()))
    124           .ConsumeValueOrDie();
    125 
    126   EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate));
    127   EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp));
    128   EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, log));
    129 
    130   // No buffers should interfere.
    131   EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, exp));
    132   EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, log));
    133   EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, negate));
    134   EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, log));
    135   EXPECT_FALSE(InstructionsMayInterfere(*liveness, log, negate));
    136   EXPECT_FALSE(InstructionsMayInterfere(*liveness, log, exp));
    137 
    138   // Buffers should interfere with itself.
    139   EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, exp));
    140 
    141   // Only log is live out.
    142   EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, param));
    143   EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, negate));
    144   EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, exp));
    145   EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, log));
    146 }
    147 
    148 TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) {
    149   // Two entry params, which interfere with each other.
    150   //
    151   // param0 --> negate ---------------\
    152   //                   param1 --> exp --> add
    153   auto builder = HloComputation::Builder(TestName());
    154   auto param0 = builder.AddInstruction(
    155       HloInstruction::CreateParameter(0, vec_, "param0"));
    156   auto param1 = builder.AddInstruction(
    157       HloInstruction::CreateParameter(1, vec_, "param1"));
    158   auto negate = builder.AddInstruction(
    159       HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param0));
    160   auto exp = builder.AddInstruction(
    161       HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param1));
    162   auto add = builder.AddInstruction(
    163       HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));
    164 
    165   auto module = CreateNewModule();
    166   HloComputation* entry = module->AddEntryComputation(builder.Build());
    167 
    168   SequentialHloOrdering::HloModuleSequence sequence;
    169   sequence.insert({entry, {param0, negate, param1, exp, add}});
    170   auto liveness =
    171       BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>(
    172                                             module.get(), sequence))
    173           .ConsumeValueOrDie();
    174 
    175   // Entry parameters interfere as if they are defined simultaneously at
    176   // the very beginning.
    177   EXPECT_TRUE(InstructionsMayInterfere(*liveness, param0, param1));
    178   EXPECT_FALSE(InstructionsMayInterfere(*liveness, param0, negate));
    179   EXPECT_FALSE(InstructionsMayInterfere(*liveness, param0, exp));
    180   EXPECT_FALSE(InstructionsMayInterfere(*liveness, param0, add));
    181   EXPECT_TRUE(InstructionsMayInterfere(*liveness, param1, param0));
    182   EXPECT_TRUE(InstructionsMayInterfere(*liveness, param1, negate));
    183   EXPECT_FALSE(InstructionsMayInterfere(*liveness, param1, exp));
    184   EXPECT_FALSE(InstructionsMayInterfere(*liveness, param1, add));
    185 
    186   // Negate and exp still interfere.
    187   EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp));
    188   EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, negate));
    189 
    190   // But {negate, add} and {exp, add} don't interfere.
    191   EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, add));
    192   EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, negate));
    193   EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, add));
    194   EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, exp));
    195 }
    196 
    197 TEST_F(BufferLivenessTest, NonElementwiseOperand) {
    198   // A chain of operations with two elementwise and one non-elementwise. The
    199   // elementwise op should not interfere with its operand, while the
    200   // non-elementwise op should interfere. Entry params always interfere.
    201   //
    202   // param --> exp -> negate -> reverse
    203   //
    204   auto builder = HloComputation::Builder(TestName());
    205   auto param =
    206       builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
    207   auto exp = builder.AddInstruction(
    208       HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param));
    209   auto negate = builder.AddInstruction(
    210       HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, exp));
    211   auto reverse =
    212       builder.AddInstruction(HloInstruction::CreateReverse(vec_, negate, {0}));
    213 
    214   auto module = CreateNewModule();
    215   module->AddEntryComputation(builder.Build());
    216 
    217   auto liveness =
    218       BufferLiveness::Run(module.get(),
    219                           xla::MakeUnique<DependencyHloOrdering>(module.get()))
    220           .ConsumeValueOrDie();
    221 
    222   EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp));
    223   EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate));
    224   EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, reverse));
    225 
    226   // Negate is elementwise, so doesn't interfere with its operand.
    227   // Reverse is non-elementwise, so does interfere with its operand.
    228   EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, negate));
    229   EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, reverse));
    230 }
    231 
    232 TEST_F(BufferLivenessTest, OverlappedBuffers) {
    233   // Verify simultaneously live buffers interfere (exp and negate).
    234   //
    235   // param --> negate -> add
    236   //     \---> exp -----/
    237   //
    238   auto builder = HloComputation::Builder(TestName());
    239   auto param =
    240       builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
    241   auto negate = builder.AddInstruction(
    242       HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
    243   auto exp = builder.AddInstruction(
    244       HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param));
    245   auto add = builder.AddInstruction(
    246       HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));
    247 
    248   auto module = CreateNewModule();
    249   module->AddEntryComputation(builder.Build());
    250 
    251   auto liveness =
    252       BufferLiveness::Run(module.get(),
    253                           xla::MakeUnique<DependencyHloOrdering>(module.get()))
    254           .ConsumeValueOrDie();
    255 
    256   EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate));
    257   EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, exp));
    258   EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add));
    259 
    260   // Negate and exp interfere with each other, but not with add.
    261   EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp));
    262   EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, negate));
    263   EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, add));
    264   EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, negate));
    265   EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, add));
    266   EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, exp));
    267 }
    268 
    269 TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) {
    270   // Identical to the test OverlappedBuffer but using a sequential ordering of
    271   // HLO instructions.
    272   //
    273   // param --> negate -> add
    274   //     \---> exp -----/
    275   //
    276   // Sequential order:
    277   //  param, negate, exp, add
    278   //
    279   // Liveness is identical to the DependencyHloOrdering.
    280   auto builder = HloComputation::Builder(TestName());
    281   auto param =
    282       builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
    283   auto negate = builder.AddInstruction(
    284       HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
    285   auto exp = builder.AddInstruction(
    286       HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param));
    287   auto add = builder.AddInstruction(
    288       HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));
    289 
    290   auto module = CreateNewModule();
    291   auto computation = module->AddEntryComputation(builder.Build());
    292 
    293   SequentialHloOrdering::HloModuleSequence module_sequence;
    294   std::vector<const HloInstruction*> order = {param, negate, exp, add};
    295   module_sequence.emplace(computation, order);
    296   auto liveness =
    297       BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>(
    298                                             module.get(), module_sequence))
    299           .ConsumeValueOrDie();
    300 
    301   EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate));
    302   EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp));
    303   EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add));
    304 
    305   // Negate and exp interfere with each other, but not with add.
    306   EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp));
    307   EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, negate));
    308   EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, add));
    309   EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, negate));
    310   EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, add));
    311   EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, exp));
    312 }
    313 
    314 TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) {
    315   // Tests that when the root instruction is not the last instruction in the
    316   // schedule, the live range of its buffers interfere with the buffers of the
    317   // later instructions.
    318   //
    319   // Two sets of independent instructions are executed in the computation.
    320   // param --> add (root)
    321   // recv --> recv-done --> send --> send-done
    322   //
    323   // Sequential order:
    324   //  param, add (root), recv, recv-done, send, send-done
    325   auto builder = HloComputation::Builder(TestName());
    326   auto param =
    327       builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
    328   auto add = builder.AddInstruction(
    329       HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, param, param));
    330   auto recv = builder.AddInstruction(
    331       HloInstruction::CreateRecv(vec_, /*channel_id=*/0));
    332   auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
    333   auto send = builder.AddInstruction(
    334       HloInstruction::CreateSend(recv_done, /*channel_id=*/1));
    335   auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
    336 
    337   auto module = CreateNewModule();
    338   auto computation = module->AddEntryComputation(builder.Build(add));
    339 
    340   SequentialHloOrdering::HloModuleSequence module_sequence;
    341   std::vector<const HloInstruction*> order = {param,     add,  recv,
    342                                               recv_done, send, send_done};
    343   module_sequence.emplace(computation, order);
    344   auto liveness =
    345       BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>(
    346                                             module.get(), module_sequence))
    347           .ConsumeValueOrDie();
    348 
    349   EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add));
    350   // Check the root instruction (add) buffer interferes with the recv buffer.
    351   EXPECT_TRUE(
    352       liveness->MayInterfere(GetBuffer(*liveness, add, /*index=*/{}),
    353                              GetBuffer(*liveness, recv, /*index=*/{0})));
    354 }
    355 
    356 TEST_F(BufferLivenessTest, TupleLiveOut) {
    357   // Verify MaybeLiveOut with nested tuples. Result of computation looks like:
    358   //
    359   //   Tuple({Tuple({Negate(Param)}, Exp(Negate(Param)))})
    360   //
    361   // All values should be live out except Param.
    362   auto builder = HloComputation::Builder(TestName());
    363   auto param =
    364       builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
    365   auto negate = builder.AddInstruction(
    366       HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
    367   auto inner_tuple =
    368       builder.AddInstruction(HloInstruction::CreateTuple({negate}));
    369   auto exp = builder.AddInstruction(
    370       HloInstruction::CreateUnary(vec_, HloOpcode::kExp, negate));
    371   auto outer_tuple =
    372       builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple, exp}));
    373 
    374   auto module = CreateNewModule();
    375   module->AddEntryComputation(builder.Build());
    376 
    377   auto liveness =
    378       BufferLiveness::Run(module.get(),
    379                           xla::MakeUnique<DependencyHloOrdering>(module.get()))
    380           .ConsumeValueOrDie();
    381 
    382   // All buffers should be live out except the param
    383   EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, param));
    384   EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, negate));
    385   EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, inner_tuple));
    386   EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, exp));
    387   EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, outer_tuple));
    388 }
    389 
    390 // bitcast liveout.
    391 
    392 TEST_F(BufferLivenessTest, EmbeddedComputation) {
    393   // Test MaybeLiveOut and MayInterfere for embedded computation.
    394   auto module = CreateNewModule();
    395 
    396   auto embedded_builder = HloComputation::Builder(TestName() + "_embedded");
    397   auto embedded_param = embedded_builder.AddInstruction(
    398       HloInstruction::CreateParameter(0, vec_, "embedded_param"));
    399   auto embedded_log = embedded_builder.AddInstruction(
    400       HloInstruction::CreateUnary(vec_, HloOpcode::kLog, embedded_param));
    401 
    402   auto embedded_computation =
    403       module->AddEmbeddedComputation(embedded_builder.Build());
    404 
    405   auto builder = HloComputation::Builder(TestName());
    406   auto param =
    407       builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
    408   auto call = builder.AddInstruction(
    409       HloInstruction::CreateCall(vec_, {param}, embedded_computation));
    410 
    411   module->AddEntryComputation(builder.Build());
    412 
    413   auto liveness =
    414       BufferLiveness::Run(module.get(),
    415                           xla::MakeUnique<DependencyHloOrdering>(module.get()))
    416           .ConsumeValueOrDie();
    417 
    418   // Buffers in different computations should always interfere.
    419   EXPECT_TRUE(InstructionsMayInterfere(*liveness, embedded_log, call));
    420   EXPECT_TRUE(InstructionsMayInterfere(*liveness, embedded_param, param));
    421   EXPECT_FALSE(
    422       InstructionsMayInterfere(*liveness, embedded_param, embedded_log));
    423 
    424   // The only buffers for which MaybeLiveOut == true are those live out
    425   // of the entry computation. Buffers live out of embedded computations should
    426   // return false for this method.
    427   EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, embedded_log));
    428   EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, call));
    429 }
    430 
    431 TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
    432   // Verify non top-level elements of a nested tuple constant are properly
    433   // marked as liveout. Computation:
    434   //
    435   //   GetTupleElement(0, TupleConstant({{0, 1}, {3}})
    436   //
    437   // Only the array buffers containing 0 and 1 are liveout of the
    438   // computation. The buffer containing {0, 1} is copied by GetTupleElement, and
    439   // the buffers containing {3} and 3 are dead.
    440   auto builder = HloComputation::Builder(TestName());
    441   auto inner_tuple0 = Literal::MakeTuple(
    442       {Literal::CreateR0<int64>(0).get(), Literal::CreateR0<int64>(1).get()});
    443   auto inner_tuple1 = Literal::MakeTuple({Literal::CreateR0<int64>(3).get()});
    444   auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
    445       Literal::MakeTuple({inner_tuple0.get(), inner_tuple1.get()})));
    446   builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    447       inner_tuple0->shape(), tuple_constant, 0));
    448 
    449   auto module = CreateNewModule();
    450   module->AddEntryComputation(builder.Build());
    451 
    452   auto liveness =
    453       BufferLiveness::Run(module.get(),
    454                           xla::MakeUnique<DependencyHloOrdering>(module.get()))
    455           .ConsumeValueOrDie();
    456 
    457   // Only the element buffers of the tuple constant which are pointed to by
    458   // the GetTupleElement instruction should be liveout.
    459   EXPECT_FALSE(liveness->MaybeLiveOut(
    460       GetBuffer(*liveness, tuple_constant, /*index=*/{})));
    461   EXPECT_TRUE(liveness->MaybeLiveOut(
    462       GetBuffer(*liveness, tuple_constant, /*index=*/{0})));
    463   EXPECT_TRUE(liveness->MaybeLiveOut(
    464       GetBuffer(*liveness, tuple_constant, /*index=*/{0, 0})));
    465   EXPECT_TRUE(liveness->MaybeLiveOut(
    466       GetBuffer(*liveness, tuple_constant, /*index=*/{0, 1})));
    467   EXPECT_FALSE(liveness->MaybeLiveOut(
    468       GetBuffer(*liveness, tuple_constant, /*index=*/{1})));
    469   EXPECT_FALSE(liveness->MaybeLiveOut(
    470       GetBuffer(*liveness, tuple_constant, /*index=*/{1, 0})));
    471   EXPECT_FALSE(liveness->MaybeLiveOut(
    472       GetBuffer(*liveness, tuple_constant, /*index=*/{1, 0})));
    473 }
    474 
    475 TEST_F(BufferLivenessTest, IndependentTupleElements) {
    476   auto builder = HloComputation::Builder(TestName());
    477   // Create param0 Tuple.
    478   auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
    479       0,
    480       ShapeUtil::MakeTupleShape(
    481           {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(S32, {4})}),
    482       "param0"));
    483   // Create independent computations for each tuple elememt.
    484 
    485   // Tuple element0 computation:
    486   //   Add(GetTupleElement(tuple_param0, 0), const0)
    487   auto tuple_element0_shape =
    488       ShapeUtil::GetSubshape(tuple_param0->shape(), {0});
    489   auto tuple_element0 =
    490       builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    491           tuple_element0_shape, tuple_param0, 0));
    492   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
    493       Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
    494   auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
    495       tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0));
    496 
    497   // Tuple element1 computation:
    498   //   Add(GetTupleElement(tuple_param0, 1), const1)
    499   auto tuple_element1_shape =
    500       ShapeUtil::GetSubshape(tuple_param0->shape(), {1});
    501   auto tuple_element1 =
    502       builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    503           tuple_element1_shape, tuple_param0, 1));
    504   auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
    505       Literal::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f})));
    506   auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
    507       tuple_element1_shape, HloOpcode::kAdd, tuple_element1, const1));
    508 
    509   // Create output tuple.
    510   auto tuple_root =
    511       builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
    512 
    513   auto module = CreateNewModule();
    514   module->AddEntryComputation(BuildDummyComputation());
    515   module->AddEmbeddedComputation(builder.Build());
    516 
    517   auto liveness =
    518       BufferLiveness::Run(module.get(),
    519                           xla::MakeUnique<DependencyHloOrdering>(module.get()))
    520           .ConsumeValueOrDie();
    521 
    522   // We compare tuple element pairs that are input/output to the computation:
    523   // 1) (input_tuple_element, output_tuple_element) = ('tuple_element0', 'add0')
    524   // 2) (input_tuple_element, output_tuple_element) = ('tuple_element1', 'add1')
    525 
    526   // Tuple output element 'add0' does not depend on input 'tuple_element1'.
    527   // Tuple output element 'add1' does not depend on input 'tuple_element0'.
    528 
    529   // Both element pair does not interfere, because there is no other dependency
    530   // on the pairs tuple input element, and so liveness can compute that all
    531   // users of the input tuple element execute before the associated output
    532   // tuple element.
    533   EXPECT_FALSE(
    534       TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {0}));
    535   EXPECT_FALSE(
    536       TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}));
    537 }
    538 
    539 TEST_F(BufferLivenessTest, DependentTupleElements) {
    540   auto builder = HloComputation::Builder(TestName());
    541   // Create param0 Tuple.
    542   auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
    543       0,
    544       ShapeUtil::MakeTupleShape(
    545           {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F32, {8})}),
    546       "param0"));
    547   // Create dependent computations for each tuple elememt.
    548 
    549   // Tuple element0 computation:
    550   //   Add(GetTupleElement(tuple_param0, 0), const0)
    551   auto tuple_element0_shape =
    552       ShapeUtil::GetSubshape(tuple_param0->shape(), {0});
    553   auto tuple_element0 =
    554       builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    555           tuple_element0_shape, tuple_param0, 0));
    556   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
    557       Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
    558   auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
    559       tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0));
    560 
    561   // Tuple element1 computation:
    562   //   Add(GetTupleElement(tuple_param0, 0), GetTupleElement(tuple_param0, 1))
    563   auto tuple_element1_shape =
    564       ShapeUtil::GetSubshape(tuple_param0->shape(), {1});
    565   auto tuple_element1 =
    566       builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    567           tuple_element1_shape, tuple_param0, 1));
    568   auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
    569       tuple_element1_shape, HloOpcode::kAdd, tuple_element0, tuple_element1));
    570 
    571   // Create output tuple.
    572   auto tuple_root =
    573       builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
    574 
    575   auto module = CreateNewModule();
    576   module->AddEntryComputation(BuildDummyComputation());
    577   module->AddEmbeddedComputation(builder.Build());
    578 
    579   auto liveness =
    580       BufferLiveness::Run(module.get(),
    581                           xla::MakeUnique<DependencyHloOrdering>(module.get()))
    582           .ConsumeValueOrDie();
    583 
    584   // We compare tuple element pairs that are input/output to the computation:
    585   // 1) (input_tuple_element, output_tuple_element) = ('tuple_element0', 'add0')
    586   // 2) (input_tuple_element, output_tuple_element) = ('tuple_element1', 'add1')
    587 
    588   // The first tuple element pair output 'add0', has no dependency on second
    589   // tuple element pairs input 'tuple_element1'.
    590 
    591   // The second tuple element pair output 'add1', has a dependency on first
    592   // tuple element pairs input 'tuple_element0'.
    593 
    594   // The first tuple element pair does interfere, because liveness cannot
    595   // compute that all references to 'tuple_element0' are executed before 'add0'
    596   // (because of the depenency of 'add1' on 'tuple_element0').
    597   EXPECT_TRUE(
    598       TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {0}));
    599 
    600   // The second tuple element pair does not interfere, because there is no
    601   // other dependency on 'tuple_element1', and so liveness can compute that
    602   // all users execute before 'add1'.
    603   EXPECT_FALSE(
    604       TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}));
    605 }
    606 
    607 class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
    608  protected:
    609   // Builds and runs a computation (see test case computation graphs below).
    610   // Runs BufferLiveness on this computation.
    611   // Returns whether buffer interference is detected between tuple-shaped
    612   // parameter and root instructions at tuple element 1.
    613   bool Run(const bool update_uses_tuple_element1,
    614            const bool fuse_gte0 = false) {
    615     auto builder = HloComputation::Builder(TestName());
    616     // Create param0 Tuple.
    617     Shape data_shape = ShapeUtil::MakeShape(F32, {8});
    618     Shape update_shape = ShapeUtil::MakeShape(F32, {3});
    619     auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
    620         0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "param0"));
    621 
    622     auto gte0 = builder.AddInstruction(
    623         HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 0));
    624 
    625     auto gte1 = builder.AddInstruction(
    626         HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));
    627 
    628     auto update = builder.AddInstruction(HloInstruction::CreateConstant(
    629         Literal::CreateR1<float>({2.f, 2.f, 2.f})));
    630     HloInstruction* slice = nullptr;
    631     if (update_uses_tuple_element1) {
    632       // Create a slice instruction as an additional user of 'gte1'.
    633       slice = builder.AddInstruction(
    634           HloInstruction::CreateSlice(update_shape, gte1, {0}, {3}, {1}));
    635       update = builder.AddInstruction(HloInstruction::CreateBinary(
    636           update_shape, HloOpcode::kAdd, update, slice));
    637     }
    638     // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
    639     auto starts = builder.AddInstruction(
    640         HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
    641     auto dynamic_update_slice =
    642         builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
    643             data_shape, gte1, update, starts));
    644     // Create output tuple.
    645     auto tuple_root = builder.AddInstruction(
    646         HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
    647     // Build module and get reference to entry computation.
    648     auto module = CreateNewModule();
    649     module->AddEntryComputation(BuildDummyComputation());
    650     auto* computation = module->AddEmbeddedComputation(builder.Build());
    651     // Create fusion instruction based on number of tuple element 1 users.
    652     if (update_uses_tuple_element1) {
    653       computation->CreateFusionInstruction(
    654           {dynamic_update_slice, starts, update, CHECK_NOTNULL(slice), gte1},
    655           HloInstruction::FusionKind::kLoop);
    656     } else {
    657       computation->CreateFusionInstruction(
    658           {dynamic_update_slice, starts, update, gte1},
    659           HloInstruction::FusionKind::kLoop);
    660     }
    661     // Create fusion instruction for tuple element 0 (if requested).
    662     if (fuse_gte0) {
    663       computation->CreateFusionInstruction({gte0},
    664                                            HloInstruction::FusionKind::kLoop);
    665     }
    666 
    667     // Run BufferLiveness on 'module'.
    668     auto liveness =
    669         BufferLiveness::Run(
    670             module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get()))
    671             .ConsumeValueOrDie();
    672     // Return whether or not buffers interference is detected between
    673     // 'tuple_param0' and 'tuple_root' at shape index '{1}'.
    674     return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
    675   }
    676 };
    677 
    678 // Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion)
    679 // do not overlap with the following computation:
    680 //
    681 //         Param0
    682 //        /     \
    683 //     GTE(0)  Fusion ----------->  FusionParam
    684 //        |      |                      |
    685 //        |      |                    GTE(1) Const Const
    686 //        |      |                      \      |    /
    687 //        |      |                    DynamicUpdateSlice  // fused root
    688 //         \    /
    689 //          Tuple  // computation root
    690 //
    691 TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) {
    692   EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false));
    693 }
    694 
    695 // Tests that live ranges of buffers Param0[1] and Tuple[1] (which aliases
    696 // 'fusion1') do not overlap in the presence of another fusion instruction
    697 // (which is a user of 'param0' at a different tuple index).
    698 // BufferLiveness should detect no uses of Param0 at index {1} in Fusion0
    699 // (because Fusion0 only uses Param0 at index {0}).
    700 //
    701 //                               Param0
    702 //                               /    \
    703 //      FusionParam  <----- Fusion0  Fusion1 ------>  FusionParam
    704 //         |                    |      |                 |
    705 //        GTE(0)                |      |               GTE(1) Const Const
    706 //                              |      |                  \      |    /
    707 //                               \    /                DynamicUpdateSlice
    708 //                               Tuple
    709 //
    710 TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) {
    711   EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false, /*fuse_gte0=*/true));
    712 }
    713 
    714 // Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion)
    715 // do overlap because GTE(1) has two users:
    716 // 1) DynamicUpdateSlice at operand 0.
    717 // 2) Slice at operand 0.
    718 //
    719 //         Param0
    720 //        /     \   Const
    721 //       /       \  /
    722 //     GTE(0)  Fusion ----------->  FusionParam FusionParam
    723 //        |      |                      |         |
    724 //        |      |                    GTE(1)      /
    725 //        |      |                      | \      /
    726 //        |      |                      | Slice /
    727 //        |      |                      |   \  /
    728 //        |      |                      |   Add   Const
    729 //        |      |                      |    |      |
    730 //        |      |                    DynamicUpdateSlice  // fused root
    731 //         \    /
    732 //          Tuple  // computation root
    733 //
    734 TEST_F(FusedDynamicUpdateSliceLivenessTest, WithInterference) {
    735   EXPECT_TRUE(Run(/*update_uses_tuple_element1=*/true));
    736 }
    737 
    738 class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
    739  protected:
    740   // Builds and runs a computation (see test case computation graphs below).
    741   // Runs BufferLiveness on this computation.
    742   // Returns whether buffer interference is detected between tuple-shaped
    743   // parameter and root instructions at tuple element 1.
    744   bool Run(const bool tuple_element1_has_two_uses) {
    745     auto builder = HloComputation::Builder(TestName());
    746     // Create param0 Tuple.
    747     Shape data_shape = ShapeUtil::MakeShape(F32, {8});
    748     Shape update_shape = ShapeUtil::MakeShape(F32, {3});
    749     auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
    750         0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "param0"));
    751 
    752     auto gte0 = builder.AddInstruction(
    753         HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 0));
    754 
    755     auto gte1 = builder.AddInstruction(
    756         HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));
    757 
    758     auto update = builder.AddInstruction(HloInstruction::CreateConstant(
    759         Literal::CreateR1<float>({2.f, 2.f, 2.f})));
    760 
    761     if (tuple_element1_has_two_uses) {
    762       // Add 'gte0' and 'gte1' to create another user of 'gte1'.
    763       gte0 = builder.AddInstruction(HloInstruction::CreateBinary(
    764           data_shape, HloOpcode::kAdd, gte0, gte1));
    765     }
    766     // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
    767     auto starts = builder.AddInstruction(
    768         HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
    769     auto dynamic_update_slice =
    770         builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
    771             data_shape, gte1, update, starts));
    772     // Create output tuple.
    773     auto tuple_root = builder.AddInstruction(
    774         HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
    775     // Build module and get reference to entry computation.
    776     auto module = CreateNewModule();
    777     module->AddEntryComputation(BuildDummyComputation());
    778     module->AddEmbeddedComputation(builder.Build());
    779     // Run BufferLiveness on 'module'.
    780     auto liveness =
    781         BufferLiveness::Run(
    782             module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get()))
    783             .ConsumeValueOrDie();
    784     // Return whether or not buffers interference is detected between
    785     // 'tuple_param0' and 'tuple_root' at shape index '{1}'.
    786     return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
    787   }
    788 };
    789 
    790 // Tests that live ranges of buffers Param0[1] and Tuple[1] do not overlap in
    791 // the following computation (because DynamicUpdateSlice (at operand 0) is the
    792 // unique user):
    793 //
    794 //     Parameter0
    795 //      |      |
    796 //    GTE(0) GTE(1) Const Const
    797 //      |      \      |    /
    798 //      |    DynamicUpdateSlice
    799 //       \    /
    800 //        Tuple
    801 //
    802 TEST_F(DynamicUpdateSliceLivenessTest, NoInterference) {
    803   EXPECT_FALSE(Run(/*tuple_element1_has_two_uses=*/false));
    804 }
    805 
    806 // Tests that live ranges of buffers Param0[1] and Tuple[1] do overlap because
    807 // GTE(1) has two users:
    808 // 1) DynamicUpdateSlice at operand 0.
    809 // 2) Add at operand 1.
    810 //
    811 //     Parameter0
    812 //      |      |
    813 //    GTE(0) GTE(1)
    814 //      |   /  |
    815 //      |  /   |
    816 //      Add    |     Const Const
    817 //      |      |      |      |
    818 //      |    DynamicUpdateSlice
    819 //       \    /
    820 //        Tuple
    821 //
    822 TEST_F(DynamicUpdateSliceLivenessTest, WithInterference) {
    823   EXPECT_TRUE(Run(/*tuple_element1_has_two_uses=*/true));
    824 }
    825 
    826 }  // namespace
    827 
    828 }  // namespace xla
    829