Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"

#include <map>
#include <memory>
#include <set>
#include <vector>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/instruction_fusion.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
     36 
     37 namespace xla {
     38 namespace {
     39 
     40 using ::testing::UnorderedElementsAre;
     41 
     42 class HloAliasAnalysisTest : public HloTestBase {
     43  protected:
     44   HloAliasAnalysisTest() : module_(CreateNewModule()) {}
     45 
     46   // Run alias analysis on the member module. For convenience returns a
     47   // reference to the generated analysis stored in analysis_.
     48   HloAliasAnalysis& RunAnalysis() {
     49     hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before alias analysis");
     50     analysis_ = HloAliasAnalysis::Run(module_.get()).ConsumeValueOrDie();
     51     return *analysis_;
     52   }
     53 
     54   // Return a vector of the buffers in the buffer set at the current position
     55   // sorted by buffer id.
     56   std::vector<HloBuffer> GetBuffersAt(const HloInstruction* instruction,
     57                                       const ShapeIndex& index = {}) const {
     58     std::set<HloBuffer::Id> buffer_ids;
     59     for (const HloValue* value : analysis_->dataflow_analysis()
     60                                      .GetValueSet(instruction, index)
     61                                      .values()) {
     62       buffer_ids.insert(analysis_->GetBufferContainingValue(*value).id());
     63     }
     64 
     65     std::vector<HloBuffer> buffers;
     66     for (HloBuffer::Id id : buffer_ids) {
     67       buffers.push_back(analysis_->GetBuffer(id));
     68     }
     69     return buffers;
     70   }
     71 
     72   // Return a vector containing all of the HloValues in the given buffer.
     73   std::vector<HloValue> GetValuesInBuffer(const HloBuffer& buffer) {
     74     std::vector<HloValue> values;
     75     for (const HloValue* value : buffer.values()) {
     76       values.push_back(*value);
     77     }
     78     return values;
     79   }
     80 
     81   // Return the HloValue defined at the given position.
     82   const HloValue& GetValueDefinedAt(const HloInstruction* instruction,
     83                                     const ShapeIndex& index = {}) const {
     84     return analysis_->dataflow_analysis().GetValueDefinedAt(instruction, index);
     85   }
     86 
     87   // Returns true if any values held in the same buffer interfere. Generally, in
     88   // the compiler pipeline copy-insertion will guarantee that this interference
     89   // never occurs, but HLO graphs with interference can be explicitly
     90   // constructed.
     91   bool AnyValuesInSameBufferInterfere() {
     92     DependencyHloOrdering ordering(module_.get());
     93     for (const HloBuffer& buffer : analysis_->buffers()) {
     94       for (const HloValue* value_a : buffer.values()) {
     95         for (const HloValue* value_b : buffer.values()) {
     96           if (*value_a != *value_b &&
     97               ordering.MayInterfere(*value_a, *value_b,
     98                                     analysis_->dataflow_analysis())) {
     99             VLOG(1) << *value_a << " interferes with " << *value_b
    100                     << " in buffer: " << buffer;
    101             return true;
    102           }
    103         }
    104       }
    105     }
    106     return false;
    107   }
    108 
    109   std::unique_ptr<HloModule> module_;
    110   std::unique_ptr<HloAliasAnalysis> analysis_;
    111 
    112   const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
    113 };
    114 
    115 TEST_F(HloAliasAnalysisTest, BinaryOperation) {
    116   // Test the analysis on a single binary operation (Add).
    117   auto builder = HloComputation::Builder(TestName());
    118   auto constant1 = builder.AddInstruction(
    119       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    120   auto constant2 = builder.AddInstruction(
    121       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    122   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
    123       scalar_shape_, HloOpcode::kAdd, constant1, constant2));
    124   module_->AddEntryComputation(builder.Build());
    125 
    126   const HloAliasAnalysis& analysis = RunAnalysis();
    127 
    128   EXPECT_EQ(analysis.buffers().size(), 3);
    129 
    130   // All of the buffer sets should trivially contain a single buffer containing
    131   // a single value.
    132   for (const HloInstruction* instruction : {constant1, constant2, add}) {
    133     EXPECT_EQ(analysis.GetUniqueBufferAt(instruction).GetUniqueValue(),
    134               GetValueDefinedAt(instruction));
    135   }
    136 
    137   EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(add));
    138   EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(add));
    139 
    140   EXPECT_FALSE(AnyValuesInSameBufferInterfere());
    141 }
    142 
    143 TEST_F(HloAliasAnalysisTest, TupleAndGtes) {
    144   // Verify the analysis for a Tuple and GetTupleElement instructions.
    145   auto builder = HloComputation::Builder(TestName());
    146   auto param0 = builder.AddInstruction(
    147       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
    148   auto param1 = builder.AddInstruction(
    149       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
    150   auto tuple =
    151       builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));
    152   auto gte0 = builder.AddInstruction(
    153       HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 0));
    154   auto gte1 = builder.AddInstruction(
    155       HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1));
    156   builder.AddInstruction(
    157       HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, gte0, gte1));
    158   module_->AddEntryComputation(builder.Build());
    159 
    160   const HloAliasAnalysis& analysis = RunAnalysis();
    161 
    162   EXPECT_EQ(analysis.buffers().size(), 4);
    163 
    164   // Verify the expected aliasing of the tuple elements.
    165   EXPECT_EQ(analysis.GetUniqueBufferAt(tuple, /*index=*/{}).GetUniqueValue(),
    166             GetValueDefinedAt(tuple, /*index=*/{}));
    167   EXPECT_EQ(analysis.GetUniqueBufferAt(tuple, /*index=*/{0}).GetUniqueValue(),
    168             GetValueDefinedAt(param0));
    169   EXPECT_EQ(analysis.GetUniqueBufferAt(tuple, /*index=*/{1}).GetUniqueValue(),
    170             GetValueDefinedAt(param1));
    171 
    172   // The tuple operand, tuple element, and result of the GTE instruction should
    173   // all be the same buffer.
    174   EXPECT_EQ(analysis.GetUniqueBufferAt(param0),
    175             analysis.GetUniqueBufferAt(tuple, /*index=*/{0}));
    176   EXPECT_EQ(analysis.GetUniqueBufferAt(param0),
    177             analysis.GetUniqueBufferAt(gte0));
    178 
    179   // Verify the positions of an aliased buffer.
    180   EXPECT_THAT(
    181       analysis.GetUniqueBufferAt(param0).ComputePositions(),
    182       UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}},
    183                            HloPosition{gte0, {}}));
    184 
    185   EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(tuple));
    186   EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(tuple));
    187 
    188   EXPECT_FALSE(AnyValuesInSameBufferInterfere());
    189 }
    190 
    191 TEST_F(HloAliasAnalysisTest, NondistinctTuple) {
    192   // Test a expression with a non-distinct buffer set.
    193   auto builder = HloComputation::Builder(TestName());
    194   auto param0 = builder.AddInstruction(
    195       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
    196   auto param1 = builder.AddInstruction(
    197       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
    198   // param0 is included twice in the tuple.
    199   auto tuple = builder.AddInstruction(
    200       HloInstruction::CreateTuple({param0, param1, param0}));
    201   module_->AddEntryComputation(builder.Build());
    202 
    203   const HloAliasAnalysis& analysis = RunAnalysis();
    204 
    205   EXPECT_THAT(
    206       analysis.GetUniqueBufferAt(param0).ComputePositions(),
    207       UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}},
    208                            HloPosition{tuple, {2}}));
    209 
    210   EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(tuple));
    211   EXPECT_FALSE(analysis.InstructionBuffersAreDistinct(tuple));
    212 
    213   EXPECT_FALSE(AnyValuesInSameBufferInterfere());
    214 }
    215 
    216 TEST_F(HloAliasAnalysisTest, SingleCall) {
    217   // Test a single call of a subcomputation. The subcomputation adds its two
    218   // array-shaped parameters.
    219   auto subbuilder = HloComputation::Builder("Subcomputation");
    220   auto subparam0 = subbuilder.AddInstruction(
    221       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
    222   auto subparam1 = subbuilder.AddInstruction(
    223       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
    224   auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary(
    225       scalar_shape_, HloOpcode::kAdd, subparam0, subparam1));
    226   HloComputation* called_computation =
    227       module_->AddEmbeddedComputation(subbuilder.Build());
    228 
    229   auto builder = HloComputation::Builder(TestName());
    230   auto constant1 = builder.AddInstruction(
    231       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    232   auto constant2 = builder.AddInstruction(
    233       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    234   auto call = builder.AddInstruction(HloInstruction::CreateCall(
    235       scalar_shape_, {constant1, constant2}, called_computation));
    236   module_->AddEntryComputation(builder.Build());
    237 
    238   const HloAliasAnalysis& analysis = RunAnalysis();
    239 
    240   // Verify aliasing of the kCall operands and the subcomputation parameters.
    241   EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).ComputePositions(),
    242               UnorderedElementsAre(HloPosition{constant1, {}},
    243                                    HloPosition{subparam0, {}}));
    244   EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).ComputePositions(),
    245               UnorderedElementsAre(HloPosition{constant2, {}},
    246                                    HloPosition{subparam1, {}}));
    247 
    248   // The subcomputation root and the kCall itself should alias.
    249   EXPECT_THAT(
    250       analysis.GetUniqueBufferAt(add).ComputePositions(),
    251       UnorderedElementsAre(HloPosition{add, {}}, HloPosition{call, {}}));
    252 
    253   EXPECT_FALSE(AnyValuesInSameBufferInterfere());
    254 }
    255 
    256 TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) {
    257   // Test a subcomputation which is called twice with different argument values.
    258   auto subbuilder = HloComputation::Builder("Subcomputation");
    259   auto subparam0 = subbuilder.AddInstruction(
    260       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
    261   auto subparam1 = subbuilder.AddInstruction(
    262       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
    263   auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary(
    264       scalar_shape_, HloOpcode::kAdd, subparam0, subparam1));
    265   HloComputation* called_computation =
    266       module_->AddEmbeddedComputation(subbuilder.Build());
    267 
    268   auto builder = HloComputation::Builder(TestName());
    269   auto constant1 = builder.AddInstruction(
    270       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    271   auto constant2 = builder.AddInstruction(
    272       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    273   auto call1 = builder.AddInstruction(HloInstruction::CreateCall(
    274       scalar_shape_, {constant1, constant2}, called_computation));
    275   auto call2 = builder.AddInstruction(HloInstruction::CreateCall(
    276       scalar_shape_, {call1, constant2}, called_computation));
    277   module_->AddEntryComputation(builder.Build());
    278 
    279   const HloAliasAnalysis& analysis = RunAnalysis();
    280 
    281   EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).ComputePositions(),
    282               UnorderedElementsAre(HloPosition{constant1, {}},
    283                                    HloPosition{subparam0, {}}));
    284   EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).ComputePositions(),
    285               UnorderedElementsAre(HloPosition{constant2, {}},
    286                                    HloPosition{subparam1, {}}));
    287 
    288   // The 'add' (root of the subcomputation) aliases the two call instruction,
    289   // and the first parameter of the subcomputation because 'call1' it is passed
    290   // as an argument to the subcomputation in 'call2'.
    291   EXPECT_THAT(
    292       analysis.GetUniqueBufferAt(add).ComputePositions(),
    293       UnorderedElementsAre(HloPosition{add, {}}, HloPosition{call1, {}},
    294                            HloPosition{subparam0, {}}, HloPosition{call2, {}}));
    295 
    296   EXPECT_THAT(GetBuffersAt(subparam0),
    297               UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1),
    298                                    analysis.GetUniqueBufferAt(add)));
    299   EXPECT_THAT(GetBuffersAt(subparam1),
    300               UnorderedElementsAre(analysis.GetUniqueBufferAt(constant2)));
    301 
    302   EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(subparam0));
    303   EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(subparam1));
    304   EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(subparam0));
    305   EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(subparam1));
    306 
    307   EXPECT_FALSE(AnyValuesInSameBufferInterfere());
    308 }
    309 
    310 TEST_F(HloAliasAnalysisTest, SingleWhile) {
    311   // Test a simple single while instruction. The while body includes a
    312   // pass-through value. HLO:
    313   //
    314   // body((F32[], F32[]) %tuple_param):
    315   //   %add = Add(%tuple_param{0}, %tuple_param{1})
    316   //   return Tuple(%tuple_param{0}, %add)
    317   //
    318   // condition((F32[], F32[]) %tuple_param):
    319   //   return Constant(false)
    320   //
    321   // entry:
    322   //   %constant1 = Constant(1.0)
    323   //   %constant2 = Constant(2.0)
    324   //   %tuple = Tuple(%constant1, %constant2)
    325   //   return While(%tuple, body, condition)
    326   //
    327   const Shape tuple_shape =
    328       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
    329 
    330   // Element 0 passes transparently through the body.
    331   auto body_builder = HloComputation::Builder("body");
    332   auto body_param = body_builder.AddInstruction(
    333       HloInstruction::CreateParameter(0, tuple_shape, "param"));
    334   auto body_element_0 = body_builder.AddInstruction(
    335       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
    336   auto body_element_1 = body_builder.AddInstruction(
    337       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
    338   auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
    339       scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
    340   auto body_tuple = body_builder.AddInstruction(
    341       HloInstruction::CreateTuple({body_element_0, add}));
    342   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
    343 
    344   // Condition computation trivially returns a constant "false".
    345   auto cond_builder = HloComputation::Builder("condition");
    346   auto cond_param = cond_builder.AddInstruction(
    347       HloInstruction::CreateParameter(0, tuple_shape, "param"));
    348   cond_builder.AddInstruction(
    349       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    350   HloComputation* condition =
    351       module_->AddEmbeddedComputation(cond_builder.Build());
    352 
    353   auto builder = HloComputation::Builder(TestName());
    354   auto constant1 = builder.AddInstruction(
    355       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    356   auto constant2 = builder.AddInstruction(
    357       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    358   auto tuple = builder.AddInstruction(
    359       HloInstruction::CreateTuple({constant1, constant2}));
    360   auto xla_while = builder.AddInstruction(
    361       HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
    362   module_->AddEntryComputation(builder.Build());
    363 
    364   const HloAliasAnalysis& analysis = RunAnalysis();
    365 
    366   // Verify the positions of the aliased while buffers.
    367   EXPECT_THAT(
    368       analysis.GetUniqueBufferAt(xla_while, /*index=*/{}).ComputePositions(),
    369       UnorderedElementsAre(HloPosition{tuple, {}}, HloPosition{xla_while, {}},
    370                            HloPosition{body_param, {}},
    371                            HloPosition{body_tuple, {}},
    372                            HloPosition{cond_param, {}}));
    373   EXPECT_THAT(
    374       analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}).ComputePositions(),
    375       UnorderedElementsAre(
    376           HloPosition{constant1, {}}, HloPosition{tuple, {0}},
    377           HloPosition{xla_while, {0}}, HloPosition{body_param, {0}},
    378           HloPosition{body_element_0, {}}, HloPosition{body_tuple, {0}},
    379           HloPosition{cond_param, {0}}));
    380   EXPECT_THAT(
    381       analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).ComputePositions(),
    382       UnorderedElementsAre(
    383           HloPosition{constant2, {}}, HloPosition{tuple, {1}},
    384           HloPosition{xla_while, {1}}, HloPosition{body_param, {1}},
    385           HloPosition{body_element_1, {}}, HloPosition{add, {}},
    386           HloPosition{body_tuple, {1}}, HloPosition{cond_param, {1}}));
    387 
    388   EXPECT_THAT(
    389       GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0})),
    390       UnorderedElementsAre(GetValueDefinedAt(constant1)));
    391   EXPECT_THAT(
    392       GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})),
    393       UnorderedElementsAre(GetValueDefinedAt(constant2),
    394                            GetValueDefinedAt(xla_while, /*index=*/{1}),
    395                            GetValueDefinedAt(body_param, {1}),
    396                            GetValueDefinedAt(cond_param, {1}),
    397                            GetValueDefinedAt(add)));
    398 
    399   EXPECT_FALSE(AnyValuesInSameBufferInterfere());
    400 }
    401 
    402 TEST_F(HloAliasAnalysisTest, SequentialWhiles) {
    403   // Test sequential while instructions. The while body includes a
    404   // pass-through value. HLO:
    405   //
    406   // body((F32[], F32[]) %tuple_param):
    407   //   %add = Add(%tuple_param{0}, %tuple_param{1})
    408   //   return Tuple(%tuple_param{0}, %add)
    409   //
    410   // condition((F32[], F32[]) %tuple_param):
    411   //   return Constant(false)
    412   //
    413   // entry:
    414   //   %constant1 = Constant(1.0)
    415   //   %constant2 = Constant(2.0)
    416   //   %tuple = Tuple(%constant1, %constant2)
    417   //   %while0 = While(%tuple, body, condition)
    418   //   %while1 = While(%while0, body, condition)
    419   //   return While(%while1, body, condition)
    420   //
    421   const Shape tuple_shape =
    422       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
    423 
    424   // Element 0 passes transparently through the body.
    425   auto body_builder = HloComputation::Builder("body");
    426   auto body_param = body_builder.AddInstruction(
    427       HloInstruction::CreateParameter(0, tuple_shape, "param"));
    428   auto body_element_0 = body_builder.AddInstruction(
    429       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
    430   auto body_element_1 = body_builder.AddInstruction(
    431       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
    432   auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
    433       scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
    434   body_builder.AddInstruction(
    435       HloInstruction::CreateTuple({body_element_0, add}));
    436   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
    437 
    438   auto cond_builder = HloComputation::Builder("condition");
    439   cond_builder.AddInstruction(
    440       HloInstruction::CreateParameter(0, tuple_shape, "param"));
    441   cond_builder.AddInstruction(
    442       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    443   HloComputation* condition =
    444       module_->AddEmbeddedComputation(cond_builder.Build());
    445 
    446   auto builder = HloComputation::Builder(TestName());
    447   auto constant1 = builder.AddInstruction(
    448       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    449   auto constant2 = builder.AddInstruction(
    450       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    451   auto tuple = builder.AddInstruction(
    452       HloInstruction::CreateTuple({constant1, constant2}));
    453   auto xla_while0 = builder.AddInstruction(
    454       HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
    455   auto xla_while1 = builder.AddInstruction(
    456       HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while0));
    457   auto xla_while2 = builder.AddInstruction(
    458       HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while1));
    459   module_->AddEntryComputation(builder.Build());
    460 
    461   FlattenCallGraph flattener;
    462   TF_ASSERT_OK(flattener.Run(module_.get()).status());
    463 
    464   const HloAliasAnalysis& analysis = RunAnalysis();
    465 
    466   EXPECT_EQ(analysis.GetUniqueBufferAt(tuple, /*index=*/{}),
    467             analysis.GetUniqueBufferAt(xla_while2, /*index=*/{}));
    468   EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
    469             analysis.GetUniqueBufferAt(xla_while2, /*index=*/{0}));
    470   EXPECT_EQ(analysis.GetUniqueBufferAt(constant2),
    471             analysis.GetUniqueBufferAt(xla_while2, /*index=*/{1}));
    472 }
    473 
    474 TEST_F(HloAliasAnalysisTest, NestedWhiles) {
    475   // Test nested while instructions. The inner body passes through element 0 of
    476   // its parameter, and the outer body passes through element 1.  HLO:
    477   //
    478   // inner_body((F32[], F32[]) %tuple_param):
    479   //   %add = Add(%tuple_param{0}, %tuple_param{1})
    480   //   return Tuple(%tuple_param{0}, %add)
    481   //
    482   // outer_body((F32[], F32[]) %tuple_param):
    483   //   %negate = Negate(%tuple_param{0})
    484   //   %tuple = Tuple(%negate, %tuple_param{1})
    485   //   return While(%tuple, inner_body, condition)
    486   //
    487   // entry:
    488   //   %constant1 = Constant(1.0)
    489   //   %constant2 = Constant(2.0)
    490   //   %tuple = Tuple(%constant1, %constant2)
    491   //   return While(%tuple, outer_body, condition)
    492   //
    493   const Shape tuple_shape =
    494       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
    495 
    496   auto build_cond_computation = [&tuple_shape]() {
    497     auto cond_builder = HloComputation::Builder("condition");
    498     cond_builder.AddInstruction(
    499         HloInstruction::CreateParameter(0, tuple_shape, "param"));
    500     cond_builder.AddInstruction(
    501         HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    502     return cond_builder.Build();
    503   };
    504   // Build separate condition computations so the call graph is flat. The
    505   // callgraph is always flattened in the compiler pipeline, and the flattened
    506   // callgraph enables representative interference analysis.
    507   HloComputation* condition1 =
    508       module_->AddEmbeddedComputation(build_cond_computation());
    509   HloComputation* condition2 =
    510       module_->AddEmbeddedComputation(build_cond_computation());
    511 
    512   // Element 0 passes transparently through the body.
    513   auto inner_builder = HloComputation::Builder("inner_body");
    514   auto inner_param = inner_builder.AddInstruction(
    515       HloInstruction::CreateParameter(0, tuple_shape, "param"));
    516   auto inner_element_0 = inner_builder.AddInstruction(
    517       HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 0));
    518   auto inner_element_1 = inner_builder.AddInstruction(
    519       HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 1));
    520   auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary(
    521       scalar_shape_, HloOpcode::kAdd, inner_element_0, inner_element_1));
    522   inner_builder.AddInstruction(
    523       HloInstruction::CreateTuple({inner_element_0, add}));
    524   HloComputation* inner_body =
    525       module_->AddEmbeddedComputation(inner_builder.Build());
    526 
    527   // Element 1 passes transparently through the body.
    528   auto outer_builder = HloComputation::Builder("outer_body");
    529   auto outer_param = outer_builder.AddInstruction(
    530       HloInstruction::CreateParameter(0, tuple_shape, "param"));
    531   auto outer_element_0 = outer_builder.AddInstruction(
    532       HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 0));
    533   auto negate = outer_builder.AddInstruction(HloInstruction::CreateUnary(
    534       scalar_shape_, HloOpcode::kNegate, outer_element_0));
    535   auto outer_element_1 = outer_builder.AddInstruction(
    536       HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 1));
    537   auto outer_tuple = outer_builder.AddInstruction(
    538       HloInstruction::CreateTuple({negate, outer_element_1}));
    539   auto nested_while = outer_builder.AddInstruction(HloInstruction::CreateWhile(
    540       tuple_shape, condition1, inner_body, outer_tuple));
    541   HloComputation* outer_body =
    542       module_->AddEmbeddedComputation(outer_builder.Build());
    543 
    544   auto builder = HloComputation::Builder(TestName());
    545   auto constant1 = builder.AddInstruction(
    546       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    547   auto constant2 = builder.AddInstruction(
    548       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    549   auto tuple = builder.AddInstruction(
    550       HloInstruction::CreateTuple({constant1, constant2}));
    551   auto entry_while = builder.AddInstruction(
    552       HloInstruction::CreateWhile(tuple_shape, condition2, outer_body, tuple));
    553   module_->AddEntryComputation(builder.Build());
    554 
    555   const HloAliasAnalysis& analysis = RunAnalysis();
    556 
    557   EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
    558             analysis.GetUniqueBufferAt(entry_while, /*index=*/{0}));
    559   EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
    560             analysis.GetUniqueBufferAt(nested_while, /*index=*/{0}));
    561   EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
    562             analysis.GetUniqueBufferAt(inner_element_0));
    563 
    564   EXPECT_EQ(analysis.GetUniqueBufferAt(constant2),
    565             analysis.GetUniqueBufferAt(entry_while, /*index=*/{1}));
    566   EXPECT_EQ(analysis.GetUniqueBufferAt(constant2),
    567             analysis.GetUniqueBufferAt(nested_while, /*index=*/{1}));
    568   EXPECT_EQ(analysis.GetUniqueBufferAt(constant2),
    569             analysis.GetUniqueBufferAt(inner_element_1));
    570 
    571   EXPECT_FALSE(AnyValuesInSameBufferInterfere());
    572 }
    573 
    574 TEST_F(HloAliasAnalysisTest, SwizzlingWhile) {
    575   // Test a while instruction with a body which permutes it's tuple parameter
    576   // elements. HLO:
    577   //
    578   // body((F32[], F32[], F32[]) %tuple_param):
    579   //   return Tuple(%tuple_param{1}, %tuple_param{2}, %tuple_param{0})
    580   //
    581   // condition((F32[], F32[]) %tuple_param):
    582   //   return Constant(false)
    583   //
    584   // entry:
    585   //   %constant1 = Constant(1.0)
    586   //   %constant2 = Constant(2.0)
    587   //   %constant3 = Constant(3.0)
    588   //   %tuple = Tuple(%constant1, %constant2, %constant3)
    589   //   return While(%tuple, body, condition)
    590   //
    591   const Shape tuple_shape =
    592       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_, scalar_shape_});
    593 
    594   auto body_builder = HloComputation::Builder("body");
    595   auto body_param = body_builder.AddInstruction(
    596       HloInstruction::CreateParameter(0, tuple_shape, "param"));
    597   auto body_element_0 = body_builder.AddInstruction(
    598       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
    599   auto body_element_1 = body_builder.AddInstruction(
    600       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
    601   auto body_element_2 = body_builder.AddInstruction(
    602       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 2));
    603   body_builder.AddInstruction(HloInstruction::CreateTuple(
    604       {body_element_1, body_element_2, body_element_0}));
    605   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
    606 
    607   auto cond_builder = HloComputation::Builder("condition");
    608   cond_builder.AddInstruction(
    609       HloInstruction::CreateParameter(0, tuple_shape, "param"));
    610   auto cond_constant = cond_builder.AddInstruction(
    611       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    612   HloComputation* condition =
    613       module_->AddEmbeddedComputation(cond_builder.Build());
    614 
    615   auto builder = HloComputation::Builder(TestName());
    616   auto constant1 = builder.AddInstruction(
    617       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    618   auto constant2 = builder.AddInstruction(
    619       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    620   auto constant3 = builder.AddInstruction(
    621       HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
    622   auto tuple = builder.AddInstruction(
    623       HloInstruction::CreateTuple({constant1, constant2, constant3}));
    624   auto xla_while = builder.AddInstruction(
    625       HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
    626   module_->AddEntryComputation(builder.Build());
    627 
    628   const HloAliasAnalysis& analysis = RunAnalysis();
    629 
    630   // The swizzling while makes most positions in the module alias leaving only 3
    631   // HloBuffers.
    632   EXPECT_THAT(
    633       analysis.buffers(),
    634       UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1),
    635                            analysis.GetUniqueBufferAt(tuple, /*index=*/{}),
    636                            analysis.GetUniqueBufferAt(cond_constant)));
    637 
    638   // The tuple elements of the while and the three constant inputs should all be
    639   // smooshed into the same buffer.
    640   EXPECT_EQ(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}),
    641             analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}));
    642   EXPECT_EQ(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}),
    643             analysis.GetUniqueBufferAt(xla_while, /*index=*/{2}));
    644   EXPECT_EQ(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}),
    645             analysis.GetUniqueBufferAt(constant1));
    646   EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
    647             analysis.GetUniqueBufferAt(constant2));
    648   EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
    649             analysis.GetUniqueBufferAt(constant3));
    650 
    651   // All elements in of the loop state tuple are forced into the same buffer
    652   // resulting liveness interference.
    653   EXPECT_TRUE(AnyValuesInSameBufferInterfere());
    654 }
    655 
    656 TEST_F(HloAliasAnalysisTest, TupleSelect) {
    657   // Test a kSelect of a tuple value. Non-top-level element flow through the
    658   // instruction.
    659   auto builder = HloComputation::Builder(TestName());
    660   auto pred = builder.AddInstruction(
    661       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    662   auto constant1 = builder.AddInstruction(
    663       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    664   auto constant2 = builder.AddInstruction(
    665       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    666   auto constant3 = builder.AddInstruction(
    667       HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
    668   auto constant4 = builder.AddInstruction(
    669       HloInstruction::CreateConstant(Literal::CreateR0<float>(4.0)));
    670   auto tuple1 =
    671       builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
    672   auto tuple2 =
    673       builder.AddInstruction(HloInstruction::CreateTuple({constant2}));
    674   auto tuple3 =
    675       builder.AddInstruction(HloInstruction::CreateTuple({constant3}));
    676   auto tuple4 =
    677       builder.AddInstruction(HloInstruction::CreateTuple({constant4}));
    678   const Shape tuple_shape = tuple1->shape();
    679   auto select11 = builder.AddInstruction(HloInstruction::CreateTernary(
    680       tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple1));
    681   auto select12 = builder.AddInstruction(HloInstruction::CreateTernary(
    682       tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple2));
    683   auto select34 = builder.AddInstruction(HloInstruction::CreateTernary(
    684       tuple_shape, HloOpcode::kSelect, pred, tuple3, tuple4));
    685   auto select1234 = builder.AddInstruction(HloInstruction::CreateTernary(
    686       tuple_shape, HloOpcode::kSelect, pred, select12, select34));
    687 
    688   module_->AddEntryComputation(builder.Build());
    689 
    690   const HloAliasAnalysis& analysis = RunAnalysis();
    691 
    692   // Verify the buffer sets of each select.
    693   EXPECT_THAT(GetBuffersAt(select11, /*index=*/{0}),
    694               UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1)));
    695   EXPECT_THAT(GetBuffersAt(select12, /*index=*/{0}),
    696               UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1),
    697                                    analysis.GetUniqueBufferAt(constant2)));
    698   EXPECT_THAT(GetBuffersAt(select34, /*index=*/{0}),
    699               UnorderedElementsAre(analysis.GetUniqueBufferAt(constant3),
    700                                    analysis.GetUniqueBufferAt(constant4)));
    701   EXPECT_THAT(GetBuffersAt(select1234, /*index=*/{0}),
    702               UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1),
    703                                    analysis.GetUniqueBufferAt(constant2),
    704                                    analysis.GetUniqueBufferAt(constant3),
    705                                    analysis.GetUniqueBufferAt(constant4)));
    706 
    707   EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(select11));
    708   EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(select12));
    709   EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(select34));
    710   EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(select1234));
    711 
    712   EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select11));
    713   EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select12));
    714   EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select34));
    715   EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select1234));
    716 
    717   EXPECT_FALSE(AnyValuesInSameBufferInterfere());
    718 }
    719 
    720 TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) {
    721   // Test a tuple-shaped kSelect feeding a kWhile instruction. HLO:
    722   //
    723   // body((F32[], F32[]) %tuple_param):
    724   //   %negate = Negate(%tuple_param{0})
    725   //   return Tuple(%negate)
    726   //
    727   // condition((F32[], F32[]) %tuple_param):
    728   //   return Constant(false)
    729   //
    730   // entry:
    731   //   %constant1 = Constant(1.0)
    732   //   %constant2 = Constant(2.0)
    733   //   %tuple1 = Tuple(%constant1)
    734   //   %tuple2 = Tuple(%constant2)
    735   //   %select = Select(%tuple1, %tuple2)
    736   //   return While(%select, body, condition)
    737   //
    738   auto builder = HloComputation::Builder(TestName());
    739 
    740   const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_});
    741 
    742   // Element 0 passes transparently through the body.
    743   auto body_builder = HloComputation::Builder("body");
    744   auto body_param = body_builder.AddInstruction(
    745       HloInstruction::CreateParameter(0, tuple_shape, "param"));
    746   auto body_element = body_builder.AddInstruction(
    747       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
    748   auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
    749       scalar_shape_, HloOpcode::kNegate, body_element));
    750   body_builder.AddInstruction(HloInstruction::CreateTuple({negate}));
    751   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
    752 
    753   auto cond_builder = HloComputation::Builder("condition");
    754   auto cond_param = cond_builder.AddInstruction(
    755       HloInstruction::CreateParameter(0, tuple_shape, "param"));
    756   cond_builder.AddInstruction(
    757       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    758   HloComputation* condition =
    759       module_->AddEmbeddedComputation(cond_builder.Build());
    760 
    761   auto pred = builder.AddInstruction(
    762       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    763   auto constant1 = builder.AddInstruction(
    764       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    765   auto constant2 = builder.AddInstruction(
    766       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    767   auto tuple1 =
    768       builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
    769   auto tuple2 =
    770       builder.AddInstruction(HloInstruction::CreateTuple({constant2}));
    771   auto select = builder.AddInstruction(HloInstruction::CreateTernary(
    772       tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple2));
    773   auto xla_while = builder.AddInstruction(
    774       HloInstruction::CreateWhile(tuple_shape, condition, body, select));
    775 
    776   module_->AddEntryComputation(builder.Build());
    777 
    778   const HloAliasAnalysis& analysis = RunAnalysis();
    779 
    780   // The while should flatten the ambiguous select buffer set so that the buffer
    781   // set contents (constant1 and constant2) becomes a single buffer.
    782   EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
    783             analysis.GetUniqueBufferAt(constant2));
    784   EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
    785             analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}));
    786 
    787   EXPECT_THAT(GetValuesInBuffer(analysis.GetUniqueBufferAt(constant1)),
    788               UnorderedElementsAre(GetValueDefinedAt(constant1),
    789                                    GetValueDefinedAt(constant2),
    790                                    GetValueDefinedAt(xla_while, /*index=*/{0}),
    791                                    GetValueDefinedAt(body_param, /*index=*/{0}),
    792                                    GetValueDefinedAt(cond_param, /*index=*/{0}),
    793                                    GetValueDefinedAt(negate)));
    794   EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(select));
    795   EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(xla_while));
    796 
    797   EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select));
    798   EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(xla_while));
    799 
    800   // The two operands of the select get flattened into the same buffer resulting
    801   // in liveness interference.
    802   EXPECT_TRUE(AnyValuesInSameBufferInterfere());
    803 }
    804 
    805 TEST_F(HloAliasAnalysisTest, Bitcast) {
    806   // Bitcasting a value should not produce a new buffer.
    807   auto builder = HloComputation::Builder(TestName());
    808   auto constant = builder.AddInstruction(
    809       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    810   auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
    811       scalar_shape_, HloOpcode::kBitcast, constant));
    812 
    813   module_->AddEntryComputation(builder.Build());
    814 
    815   const HloAliasAnalysis& analysis = RunAnalysis();
    816 
    817   EXPECT_EQ(analysis.buffers().size(), 1);
    818 
    819   EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
    820             analysis.GetUniqueBufferAt(bitcast));
    821 }
    822 
    823 TEST_F(HloAliasAnalysisTest, BitcastInterference) {
    824   // A bitcast value simultaneously live with its operand should not cause
    825   // interference.
    826   auto builder = HloComputation::Builder(TestName());
    827   auto constant = builder.AddInstruction(
    828       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    829   auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
    830       scalar_shape_, HloOpcode::kBitcast, constant));
    831   builder.AddInstruction(HloInstruction::CreateTuple({constant, bitcast}));
    832 
    833   module_->AddEntryComputation(builder.Build());
    834 
    835   const HloAliasAnalysis& analysis = RunAnalysis();
    836 
    837   DependencyHloOrdering ordering(module_.get());
    838   EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering));
    839 }
    840 
    841 TEST_F(HloAliasAnalysisTest, WhileInterference) {
    842   // Build a while loop which has a parallel use of the init value. Depending on
    843   // ordering there may be interference between the update-in-place while and
    844   // the other use of the init.
    845   auto builder = HloComputation::Builder(TestName());
    846   auto init = builder.AddInstruction(
    847       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    848 
    849   auto cond_builder = HloComputation::Builder("condition");
    850   auto cond_param = cond_builder.AddInstruction(
    851       HloInstruction::CreateParameter(0, init->shape(), "param"));
    852   auto cond_root = cond_builder.AddInstruction(
    853       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    854   HloComputation* condition =
    855       module_->AddEmbeddedComputation(cond_builder.Build());
    856 
    857   auto body_builder = HloComputation::Builder("body");
    858   auto body_param = body_builder.AddInstruction(
    859       HloInstruction::CreateParameter(0, init->shape(), "param"));
    860   auto body_root = body_builder.AddInstruction(
    861       HloInstruction::CreateUnary(init->shape(), HloOpcode::kExp, body_param));
    862   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
    863 
    864   auto xla_while = builder.AddInstruction(
    865       HloInstruction::CreateWhile(init->shape(), condition, body, init));
    866 
    867   auto negate = builder.AddInstruction(
    868       HloInstruction::CreateUnary(init->shape(), HloOpcode::kNegate, init));
    869   auto entry_root =
    870       builder.AddInstruction(HloInstruction::CreateTuple({negate, xla_while}));
    871 
    872   HloComputation* entry = module_->AddEntryComputation(builder.Build());
    873 
    874   const HloAliasAnalysis& analysis = RunAnalysis();
    875 
    876   {
    877     // Dependency ordering should interfere because the negate and while are
    878     // unordered.
    879     DependencyHloOrdering ordering(module_.get());
    880     EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering));
    881   }
    882 
    883   // For a sequential order, if there is interference iff the negate is after
    884   // the while.
    885   SequentialHloOrdering::HloModuleSequence sequence;
    886   sequence[body] = {body_param, body_root};
    887   sequence[condition] = {cond_param, cond_root};
    888   {
    889     sequence[entry] = {init, xla_while, negate, entry_root};
    890     SequentialHloOrdering ordering(module_.get(), sequence);
    891     EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering));
    892   }
    893 
    894   {
    895     sequence[entry] = {init, negate, xla_while, entry_root};
    896     SequentialHloOrdering ordering(module_.get(), sequence);
    897     EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering));
    898   }
    899 }
    900 
    901 }  // namespace
    902 }  // namespace xla
    903