Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
     17 
     18 #include <map>
     19 #include <memory>
     20 
     21 #include "tensorflow/compiler/xla/literal_util.h"
     22 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
     23 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     24 #include "tensorflow/compiler/xla/service/instruction_fusion.h"
     25 #include "tensorflow/compiler/xla/shape_util.h"
     26 #include "tensorflow/compiler/xla/test.h"
     27 #include "tensorflow/compiler/xla/test_helpers.h"
     28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     29 #include "tensorflow/compiler/xla/xla_data.pb.h"
     30 #include "tensorflow/core/platform/logging.h"
     31 #include "tensorflow/core/platform/test.h"
     32 
     33 namespace op = xla::testing::opcode_matchers;
     34 
     35 namespace xla {
     36 namespace {
     37 
     38 using ::testing::UnorderedElementsAre;
     39 using ::testing::UnorderedElementsAreArray;
     40 
     41 class TuplePointsToAnalysisTest : public HloTestBase {
     42  protected:
     43   // Builds a module with the given entry computation and runs points to
     44   // analysis.
     45   void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
     46     BuildModule(std::move(computation));
     47     RunAnalysis();
     48   }
     49 
     50   void BuildModule(std::unique_ptr<HloComputation> computation) {
     51     module_ = CreateNewModule();
     52     module_->AddEntryComputation(std::move(computation));
     53   }
     54 
     55   void RunAnalysis() {
     56     CHECK_NOTNULL(module_.get());
     57     points_to_analysis_ =
     58         TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
     59   }
     60 
     61   // Returns the LogicalBuffer defined at the given instruction and
     62   // index. CHECKs if no buffer is defined at that point.
     63   const LogicalBuffer* const GetBuffer(const HloInstruction* instruction,
     64                                        const ShapeIndex& index) {
     65     const auto& pointed_to =
     66         points_to_analysis_->GetPointsToSet(instruction).element(index);
     67     CHECK_EQ(1, pointed_to.size());
     68     CHECK_EQ(instruction, pointed_to[0]->instruction());
     69     CHECK(index == pointed_to[0]->index());
     70     return pointed_to[0];
     71   }
     72 
     73   // Checks that the given points-to set contains exactly (unordered) the given
     74   // LogicalBuffers.
     75   void ExpectHasBuffers(
     76       const PointsToSet::BufferList& points_to_set,
     77       tensorflow::gtl::ArraySlice<const LogicalBuffer*> buffers) {
     78     std::vector<const LogicalBuffer*> vec(buffers.begin(), buffers.end());
     79     EXPECT_THAT(points_to_set, UnorderedElementsAreArray(vec));
     80   }
     81 
     82   // Checks that the given points-to set contains exactly (unordered) the
     83   // top-level buffers of the given instructions.
     84   void ExpectHasTopLevelBuffers(
     85       const PointsToSet::BufferList& points_to_set,
     86       tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
     87     PointsToSet::BufferList buffers;
     88     for (auto instruction : instructions) {
     89       buffers.push_back(GetBuffer(instruction, /*index=*/{}));
     90     }
     91     ExpectHasBuffers(points_to_set, buffers);
     92   }
     93 
     94   // Overload which takes a set instead of a vector.
     95   void ExpectHasTopLevelBuffers(
     96       const PointsToSet::BufferSet& points_to_set,
     97       tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
     98     ExpectHasTopLevelBuffers(
     99         PointsToSet::BufferList(points_to_set.begin(), points_to_set.end()),
    100         instructions);
    101   }
    102 
    103   // Checks that the buffer defined at the given instruction and index has
    104   // aliases which are exactly (unordered) the given instruction/index pairs.
    105   void ExpectHasBufferAliases(
    106       const HloInstruction* instruction, const ShapeIndex& index,
    107       tensorflow::gtl::ArraySlice<std::pair<HloInstruction*, ShapeIndex>>
    108           expected) {
    109     const LogicalBuffer* buffer =
    110         points_to_analysis_->GetBufferDefinedAt(instruction, index)
    111             .ValueOrDie();
    112     std::vector<BufferAlias> expected_aliases;
    113     for (auto& pair : expected) {
    114       expected_aliases.push_back(BufferAlias(pair.first, pair.second));
    115     }
    116     EXPECT_THAT(points_to_analysis_->GetBufferAliases(*buffer),
    117                 UnorderedElementsAreArray(expected_aliases));
    118   }
    119 
    120   std::unique_ptr<HloModule> module_;
    121   std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
    122 };
    123 
    124 TEST_F(TuplePointsToAnalysisTest, SimpleTuple) {
    125   auto builder = HloComputation::Builder(TestName());
    126   auto constant1 = builder.AddInstruction(
    127       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    128   auto constant2 = builder.AddInstruction(
    129       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    130   auto tuple = builder.AddInstruction(
    131       HloInstruction::CreateTuple({constant1, constant2}));
    132 
    133   BuildModuleAndRunAnalysis(builder.Build());
    134   EXPECT_EQ(1, points_to_analysis_->GetPointsToSet(constant1).size());
    135   ExpectHasTopLevelBuffers(
    136       points_to_analysis_->GetPointsToSet(constant1).element({}), {constant1});
    137   EXPECT_TRUE(
    138       points_to_analysis_->GetPointsToSet(constant1).tuple_sources({}).empty());
    139   EXPECT_TRUE(points_to_analysis_->GetPointsToSet(tuple).IsDistinct());
    140 
    141   EXPECT_EQ(1, points_to_analysis_->GetPointsToSet(constant2).size());
    142   ExpectHasTopLevelBuffers(
    143       points_to_analysis_->GetPointsToSet(constant2).element({}), {constant2});
    144   EXPECT_TRUE(
    145       points_to_analysis_->GetPointsToSet(constant2).tuple_sources({}).empty());
    146 
    147   EXPECT_EQ(3, points_to_analysis_->GetPointsToSet(tuple).size());
    148   EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous());
    149   EXPECT_THAT(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}),
    150               UnorderedElementsAre(tuple));
    151 
    152   ExpectHasTopLevelBuffers(
    153       points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(),
    154       {constant1, constant2, tuple});
    155   ExpectHasTopLevelBuffers(
    156       points_to_analysis_->GetPointsToSet(tuple).element({}), {tuple});
    157   ExpectHasTopLevelBuffers(
    158       points_to_analysis_->GetPointsToSet(tuple).element({0}), {constant1});
    159   ExpectHasTopLevelBuffers(
    160       points_to_analysis_->GetPointsToSet(tuple).element({1}), {constant2});
    161 
    162   const PointsToSet& tuple_points_to_set =
    163       points_to_analysis_->GetPointsToSet(tuple);
    164   EXPECT_TRUE(tuple_points_to_set.ContainsBufferAtIndex(
    165       *GetBuffer(constant1, {}), {0}));
    166   EXPECT_TRUE(tuple_points_to_set.ContainsBufferAtIndex(
    167       *GetBuffer(constant2, {}), {1}));
    168   EXPECT_FALSE(tuple_points_to_set.ContainsBufferAtIndex(
    169       *GetBuffer(constant2, {}), {0}));
    170   EXPECT_TRUE(tuple_points_to_set.ContainsBuffer(*GetBuffer(constant1, {})));
    171   EXPECT_TRUE(tuple_points_to_set.ContainsBuffer(*GetBuffer(constant2, {})));
    172 }
    173 
    174 TEST_F(TuplePointsToAnalysisTest, NestedTuple) {
    175   // Create a (nested) tuple containing an inner tuple. The points-to set of the
    176   // outer tuple should contain all elements of the points-to set of the inner
    177   // tuple.
    178   auto builder = HloComputation::Builder(TestName());
    179   auto constant1 = builder.AddInstruction(
    180       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    181   auto constant2 = builder.AddInstruction(
    182       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    183   auto inner_tuple = builder.AddInstruction(
    184       HloInstruction::CreateTuple({constant1, constant2}));
    185 
    186   auto constant3 = builder.AddInstruction(
    187       HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
    188   auto tuple = builder.AddInstruction(
    189       HloInstruction::CreateTuple({inner_tuple, constant3}));
    190 
    191   BuildModuleAndRunAnalysis(builder.Build());
    192   ExpectHasTopLevelBuffers(
    193       points_to_analysis_->GetPointsToSet(constant1).element({}), {constant1});
    194   ExpectHasTopLevelBuffers(
    195       points_to_analysis_->GetPointsToSet(constant2).element({}), {constant2});
    196   ExpectHasTopLevelBuffers(
    197       points_to_analysis_->GetPointsToSet(constant3).element({}), {constant3});
    198 
    199   EXPECT_EQ(3, points_to_analysis_->GetPointsToSet(inner_tuple).size());
    200   EXPECT_FALSE(points_to_analysis_->GetPointsToSet(inner_tuple).IsAmbiguous());
    201   EXPECT_TRUE(points_to_analysis_->GetPointsToSet(inner_tuple).IsDistinct());
    202   ExpectHasTopLevelBuffers(
    203       points_to_analysis_->GetPointsToSet(inner_tuple).CreateFlattenedSet(),
    204       {constant1, constant2, inner_tuple});
    205   ExpectHasTopLevelBuffers(
    206       points_to_analysis_->GetPointsToSet(inner_tuple).element({}),
    207       {inner_tuple});
    208   EXPECT_THAT(
    209       points_to_analysis_->GetPointsToSet(inner_tuple).tuple_sources({}),
    210       UnorderedElementsAre(inner_tuple));
    211 
    212   EXPECT_EQ(5, points_to_analysis_->GetPointsToSet(tuple).size());
    213   EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous());
    214   ExpectHasTopLevelBuffers(
    215       points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(),
    216       {constant1, constant2, constant3, inner_tuple, tuple});
    217 
    218   EXPECT_THAT(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}),
    219               UnorderedElementsAre(tuple));
    220   EXPECT_THAT(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({0}),
    221               UnorderedElementsAre(inner_tuple));
    222   EXPECT_TRUE(
    223       points_to_analysis_->GetPointsToSet(tuple).tuple_sources({1}).empty());
    224 
    225   ExpectHasTopLevelBuffers(
    226       points_to_analysis_->GetPointsToSet(tuple).element({0}), {inner_tuple});
    227   ExpectHasTopLevelBuffers(
    228       points_to_analysis_->GetPointsToSet(tuple).element({0, 0}), {constant1});
    229   ExpectHasTopLevelBuffers(
    230       points_to_analysis_->GetPointsToSet(tuple).element({0, 1}), {constant2});
    231   ExpectHasTopLevelBuffers(
    232       points_to_analysis_->GetPointsToSet(tuple).element({1}), {constant3});
    233 }
    234 
    235 TEST_F(TuplePointsToAnalysisTest, GetTupleElement) {
    236   // Create a nested tuple, then extract the inner tuple with GetTupleElement.
    237   // The points-to set of the GetTupleElement should be the same as the inner
    238   // tuple.
    239   auto builder = HloComputation::Builder(TestName());
    240   auto constant1 = builder.AddInstruction(
    241       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    242   auto constant2 = builder.AddInstruction(
    243       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    244   auto inner_tuple = builder.AddInstruction(
    245       HloInstruction::CreateTuple({constant1, constant2}));
    246 
    247   auto constant3 = builder.AddInstruction(
    248       HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
    249   auto tuple = builder.AddInstruction(
    250       HloInstruction::CreateTuple({inner_tuple, constant3}));
    251 
    252   auto get_tuple_element = builder.AddInstruction(
    253       HloInstruction::CreateGetTupleElement(inner_tuple->shape(), tuple, 0));
    254 
    255   BuildModuleAndRunAnalysis(builder.Build());
    256 
    257   auto& points_to_set = points_to_analysis_->GetPointsToSet(get_tuple_element);
    258   EXPECT_EQ(3, points_to_set.size());
    259   EXPECT_FALSE(points_to_set.IsAmbiguous());
    260   EXPECT_TRUE(points_to_set.IsDistinct());
    261   ExpectHasTopLevelBuffers(points_to_set.CreateFlattenedSet(),
    262                            {constant1, constant2, inner_tuple});
    263   ExpectHasTopLevelBuffers(points_to_set.element({}), {inner_tuple});
    264 
    265   EXPECT_THAT(points_to_set.tuple_sources({}),
    266               UnorderedElementsAre(inner_tuple));
    267 }
    268 
    269 TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) {
    270   // Create a tuple which contains duplicate elements.
    271   auto builder = HloComputation::Builder(TestName());
    272   auto constant = builder.AddInstruction(
    273       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    274   auto tuple = builder.AddInstruction(
    275       HloInstruction::CreateTuple({constant, constant, constant}));
    276 
    277   BuildModuleAndRunAnalysis(builder.Build());
    278 
    279   EXPECT_EQ(2, points_to_analysis_->GetPointsToSet(tuple).size());
    280   EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous());
    281   EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsDistinct());
    282   ExpectHasTopLevelBuffers(
    283       points_to_analysis_->GetPointsToSet(tuple).element({}), {tuple});
    284   ExpectHasTopLevelBuffers(
    285       points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(),
    286       {constant, tuple});
    287 }
    288 
    289 TEST_F(TuplePointsToAnalysisTest, TupleCopy) {
    290   // Create a copy (HloOpcode::kCopy) of a tuple. The points to sets should be
    291   // the same.
    292   auto builder = HloComputation::Builder(TestName());
    293   auto constant1 = builder.AddInstruction(
    294       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    295   auto constant2 = builder.AddInstruction(
    296       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    297   auto tuple = builder.AddInstruction(
    298       HloInstruction::CreateTuple({constant1, constant2}));
    299   auto copy = builder.AddInstruction(
    300       HloInstruction::CreateUnary(tuple->shape(), HloOpcode::kCopy, tuple));
    301 
    302   BuildModuleAndRunAnalysis(builder.Build());
    303 
    304   EXPECT_FALSE(points_to_analysis_->GetPointsToSet(copy).IsAmbiguous());
    305   EXPECT_TRUE(points_to_analysis_->GetPointsToSet(copy).IsDistinct());
    306   ExpectHasTopLevelBuffers(
    307       points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(),
    308       {constant1, constant2, tuple});
    309   ExpectHasTopLevelBuffers(
    310       points_to_analysis_->GetPointsToSet(copy).element({}), {copy});
    311   ExpectHasTopLevelBuffers(
    312       points_to_analysis_->GetPointsToSet(copy).CreateFlattenedSet(),
    313       {constant1, constant2, copy});
    314 }
    315 
    316 TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) {
    317   // Send forwards its operand to the output tuple at {0}.
    318   auto builder = HloComputation::Builder(TestName());
    319   auto constant = builder.AddInstruction(
    320       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    321   auto send = builder.AddInstruction(
    322       HloInstruction::CreateSend(constant, /*channel_id=*/0));
    323   auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
    324 
    325   BuildModuleAndRunAnalysis(builder.Build());
    326 
    327   EXPECT_FALSE(points_to_analysis_->GetPointsToSet(send).IsAmbiguous());
    328   EXPECT_TRUE(points_to_analysis_->GetPointsToSet(send).IsDistinct());
    329   EXPECT_FALSE(points_to_analysis_->GetPointsToSet(send_done).IsAmbiguous());
    330   EXPECT_TRUE(points_to_analysis_->GetPointsToSet(send_done).IsDistinct());
    331 
    332   ExpectHasTopLevelBuffers(
    333       points_to_analysis_->GetPointsToSet(send).element({}), {send});
    334   ExpectHasTopLevelBuffers(
    335       points_to_analysis_->GetPointsToSet(send).element({0}), {constant});
    336   ExpectHasTopLevelBuffers(
    337       points_to_analysis_->GetPointsToSet(send_done).CreateFlattenedSet(),
    338       {send_done});
    339   ExpectHasBufferAliases(constant, {}, {{constant, {}}, {send, {0}}});
    340 }
    341 
    342 TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) {
    343   // RecvDone forwards its operand tuple element at {0} to the output.
    344   auto builder = HloComputation::Builder(TestName());
    345   auto recv = builder.AddInstruction(HloInstruction::CreateRecv(
    346       ShapeUtil::MakeShape(F32, {1, 2, 3}), /*channel_id=*/0));
    347   auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
    348 
    349   BuildModuleAndRunAnalysis(builder.Build());
    350 
    351   EXPECT_FALSE(points_to_analysis_->GetPointsToSet(recv).IsAmbiguous());
    352   EXPECT_TRUE(points_to_analysis_->GetPointsToSet(recv).IsDistinct());
    353   EXPECT_FALSE(points_to_analysis_->GetPointsToSet(recv_done).IsAmbiguous());
    354   EXPECT_TRUE(points_to_analysis_->GetPointsToSet(recv_done).IsDistinct());
    355 
    356   ExpectHasTopLevelBuffers(
    357       points_to_analysis_->GetPointsToSet(recv).element({}), {recv});
    358   ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {}}});
    359 }
    360 
    361 TEST_F(TuplePointsToAnalysisTest, TupleSelect) {
    362   // Select from two different tuples. This should create an ambiguous points to
    363   // set containing the union of both sides.
    364   auto builder = HloComputation::Builder(TestName());
    365   auto constant1 = builder.AddInstruction(
    366       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    367   auto constant2 = builder.AddInstruction(
    368       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    369   auto tuple1 = builder.AddInstruction(
    370       HloInstruction::CreateTuple({constant1, constant2}));
    371   auto tuple2 = builder.AddInstruction(
    372       HloInstruction::CreateTuple({constant2, constant2}));
    373 
    374   auto pred = builder.AddInstruction(
    375       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    376   auto select = builder.AddInstruction(HloInstruction::CreateTernary(
    377       tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
    378 
    379   BuildModuleAndRunAnalysis(builder.Build());
    380 
    381   auto& points_to_set = points_to_analysis_->GetPointsToSet(select);
    382   EXPECT_EQ(3, points_to_set.size());
    383   EXPECT_TRUE(points_to_set.IsAmbiguous());
    384   EXPECT_FALSE(points_to_set.IsDistinct());
    385   ExpectHasTopLevelBuffers(points_to_set.element({}), {select});
    386   ExpectHasTopLevelBuffers(points_to_set.element({0}), {constant1, constant2});
    387   ExpectHasTopLevelBuffers(points_to_set.element({1}), {constant2});
    388   ExpectHasTopLevelBuffers(points_to_set.CreateFlattenedSet(),
    389                            {constant1, constant2, select});
    390 }
    391 
    392 TEST_F(TuplePointsToAnalysisTest, SelectTupleParameters) {
    393   // Create a Select which selects between two tuple parameters. Verify the
    394   // points-to sets and tuple sources are properly set.
    395   Shape tuple_shape = ShapeUtil::MakeTupleShape(
    396       {ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeShape(U32, {5})});
    397 
    398   auto builder = HloComputation::Builder(TestName());
    399   auto param0 = builder.AddInstruction(
    400       HloInstruction::CreateParameter(0, tuple_shape, "param0"));
    401   auto param1 = builder.AddInstruction(
    402       HloInstruction::CreateParameter(1, tuple_shape, "param1"));
    403   auto pred = builder.AddInstruction(
    404       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    405   auto select = builder.AddInstruction(HloInstruction::CreateTernary(
    406       tuple_shape, HloOpcode::kSelect, pred, param0, param1));
    407   auto copy = builder.AddInstruction(
    408       HloInstruction::CreateUnary(tuple_shape, HloOpcode::kCopy, select));
    409 
    410   BuildModuleAndRunAnalysis(builder.Build());
    411 
    412   // The points-to set of each element of a tuple parameters should be itself
    413   // with the appropriate index.
    414   ExpectHasBuffers(points_to_analysis_->GetPointsToSet(param0).element({}),
    415                    {GetBuffer(param0, {})});
    416   ExpectHasBuffers(points_to_analysis_->GetPointsToSet(param0).element({0}),
    417                    {GetBuffer(param0, {0})});
    418   ExpectHasBuffers(points_to_analysis_->GetPointsToSet(param0).element({1}),
    419                    {GetBuffer(param0, {1})});
    420 
    421   // Select's point-to set of its subelements should be the respective
    422   // subelements of param0 and param1. The top-level buffer, however, does not
    423   // alias as it is created by the select instruction.
    424   ExpectHasBuffers(points_to_analysis_->GetPointsToSet(select).element({}),
    425                    {GetBuffer(select, {})});
    426   ExpectHasBuffers(points_to_analysis_->GetPointsToSet(select).element({0}),
    427                    {GetBuffer(param0, {0}), GetBuffer(param1, {0})});
    428   ExpectHasBuffers(points_to_analysis_->GetPointsToSet(select).element({1}),
    429                    {GetBuffer(param0, {1}), GetBuffer(param1, {1})});
    430 
    431   // Copy should be identical to select other than the top-level buffer.
    432   ExpectHasBuffers(points_to_analysis_->GetPointsToSet(copy).element({}),
    433                    {GetBuffer(copy, {})});
    434   ExpectHasBuffers(points_to_analysis_->GetPointsToSet(copy).element({0}),
    435                    {GetBuffer(param0, {0}), GetBuffer(param1, {0})});
    436   ExpectHasBuffers(points_to_analysis_->GetPointsToSet(copy).element({1}),
    437                    {GetBuffer(param0, {1}), GetBuffer(param1, {1})});
    438 }
    439 
    440 TEST_F(TuplePointsToAnalysisTest, UnambiguousTupleSelect) {
    441   // Select from two identical tuples. The result should not be ambiguous.
    442   auto builder = HloComputation::Builder(TestName());
    443   auto constant1 = builder.AddInstruction(
    444       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    445   auto constant2 = builder.AddInstruction(
    446       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    447   auto tuple1 = builder.AddInstruction(
    448       HloInstruction::CreateTuple({constant1, constant2}));
    449   auto tuple2 = builder.AddInstruction(
    450       HloInstruction::CreateTuple({constant1, constant2}));
    451 
    452   auto pred = builder.AddInstruction(
    453       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    454   auto select = builder.AddInstruction(HloInstruction::CreateTernary(
    455       tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
    456 
    457   BuildModuleAndRunAnalysis(builder.Build());
    458 
    459   auto& points_to_set = points_to_analysis_->GetPointsToSet(select);
    460   EXPECT_EQ(3, points_to_set.size());
    461   EXPECT_FALSE(points_to_set.IsAmbiguous());
    462   EXPECT_TRUE(points_to_set.IsDistinct());
    463   ExpectHasTopLevelBuffers(points_to_set.element({}), {select});
    464   ExpectHasTopLevelBuffers(points_to_set.element({0}), {constant1});
    465   ExpectHasTopLevelBuffers(points_to_set.element({1}), {constant2});
    466   ExpectHasTopLevelBuffers(points_to_set.CreateFlattenedSet(),
    467                            {constant1, constant2, select});
    468 }
    469 
    470 TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) {
    471   // Select from nested tuples. Verify that the nested points-to sets contain
    472   // the right values.
    473   auto builder = HloComputation::Builder(TestName());
    474   auto constant1 = builder.AddInstruction(
    475       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    476   auto constant2 = builder.AddInstruction(
    477       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    478   auto inner_tuple1 = builder.AddInstruction(
    479       HloInstruction::CreateTuple({constant1, constant2}));
    480   auto inner_tuple2 = builder.AddInstruction(
    481       HloInstruction::CreateTuple({constant2, constant2}));
    482 
    483   auto tuple1 =
    484       builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple1}));
    485   auto tuple2 =
    486       builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple2}));
    487 
    488   auto pred = builder.AddInstruction(
    489       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    490   auto select = builder.AddInstruction(HloInstruction::CreateTernary(
    491       tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
    492 
    493   BuildModuleAndRunAnalysis(builder.Build());
    494 
    495   auto& points_to_set = points_to_analysis_->GetPointsToSet(select);
    496   EXPECT_EQ(5, points_to_set.size());
    497   EXPECT_TRUE(points_to_set.IsAmbiguous());
    498   EXPECT_FALSE(points_to_set.IsDistinct());
    499 
    500   // Verify points-to set.
    501   ExpectHasTopLevelBuffers(points_to_set.element({}), {select});
    502   ExpectHasTopLevelBuffers(points_to_set.element({0}),
    503                            {inner_tuple1, inner_tuple2});
    504   ExpectHasTopLevelBuffers(points_to_set.element({0, 0}),
    505                            {constant1, constant2});
    506   ExpectHasTopLevelBuffers(points_to_set.element({0, 1}), {constant2});
    507 
    508   // Verify tuple sources.
    509   EXPECT_THAT(points_to_set.tuple_sources({}),
    510               UnorderedElementsAre(tuple1, tuple2));
    511   EXPECT_THAT(points_to_set.tuple_sources({0}),
    512               UnorderedElementsAre(inner_tuple1, inner_tuple2));
    513   EXPECT_EQ(0, points_to_set.tuple_sources({0, 0}).size());
    514   EXPECT_EQ(0, points_to_set.tuple_sources({0, 1}).size());
    515 }
    516 
    517 TEST_F(TuplePointsToAnalysisTest, TupleWithBitcast) {
    518   // Bitcast is an alias of its operand. A tuple with a bitcast element should
    519   // have the operand of the bitcast in its points-to set.
    520   auto builder = HloComputation::Builder(TestName());
    521   auto constant1 = builder.AddInstruction(
    522       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    523   auto constant2 = builder.AddInstruction(
    524       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    525   auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
    526       constant2->shape(), HloOpcode::kBitcast, constant2));
    527   auto tuple =
    528       builder.AddInstruction(HloInstruction::CreateTuple({constant1, bitcast}));
    529 
    530   BuildModuleAndRunAnalysis(builder.Build());
    531 
    532   EXPECT_EQ(1, points_to_analysis_->GetPointsToSet(bitcast).size());
    533   ExpectHasTopLevelBuffers(
    534       points_to_analysis_->GetPointsToSet(bitcast).element({}), {constant2});
    535   EXPECT_TRUE(
    536       points_to_analysis_->GetPointsToSet(bitcast).tuple_sources({}).empty());
    537 
    538   EXPECT_EQ(3, points_to_analysis_->GetPointsToSet(tuple).size());
    539   EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous());
    540   EXPECT_THAT(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}),
    541               UnorderedElementsAre(tuple));
    542 
    543   ExpectHasTopLevelBuffers(
    544       points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(),
    545       {constant1, constant2, tuple});
    546   ExpectHasTopLevelBuffers(
    547       points_to_analysis_->GetPointsToSet(tuple).element({}), {tuple});
    548   ExpectHasTopLevelBuffers(
    549       points_to_analysis_->GetPointsToSet(tuple).element({0}), {constant1});
    550   ExpectHasTopLevelBuffers(
    551       points_to_analysis_->GetPointsToSet(tuple).element({1}), {constant2});
    552 }
    553 
    554 TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) {
    555   // Construct a tuple constant and kCopy it. Verify the points-to set of the
    556   // copy correctly correctly points into the nested elements of the constant.
    557   auto builder = HloComputation::Builder(TestName());
    558   auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
    559       Literal::MakeTuple({Literal::CreateR2<float>({{1.0}, {2.0}}).get(),
    560                           Literal::CreateR1<float>({2.0, 42}).get()})));
    561   auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
    562       tuple_constant->shape(), HloOpcode::kCopy, tuple_constant));
    563 
    564   BuildModuleAndRunAnalysis(builder.Build());
    565 
    566   auto& points_to_set = points_to_analysis_->GetPointsToSet(copy);
    567 
    568   ExpectHasBuffers(points_to_set.element({}), {GetBuffer(copy, {})});
    569   ExpectHasBuffers(points_to_set.element({0}),
    570                    {GetBuffer(tuple_constant, {0})});
    571   ExpectHasBuffers(points_to_set.element({1}),
    572                    {GetBuffer(tuple_constant, {1})});
    573 }
    574 
    575 TEST_F(TuplePointsToAnalysisTest, BufferAliases) {
    576   // Create a nested tuple in which individual elements appear multiple
    577   // times. Verify buffer alias sets.
    578   auto builder = HloComputation::Builder(TestName());
    579   auto constant1 = builder.AddInstruction(
    580       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    581   auto constant2 = builder.AddInstruction(
    582       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    583   auto inner_tuple = builder.AddInstruction(
    584       HloInstruction::CreateTuple({constant1, constant2}));
    585   auto tuple = builder.AddInstruction(
    586       HloInstruction::CreateTuple({inner_tuple, constant2}));
    587 
    588   BuildModuleAndRunAnalysis(builder.Build());
    589 
    590   ExpectHasBufferAliases(
    591       constant1, /*index=*/{},
    592       {{constant1, {}}, {inner_tuple, {0}}, {tuple, {0, 0}}});
    593   ExpectHasBufferAliases(
    594       constant2, /*index=*/{},
    595       {{constant2, {}}, {inner_tuple, {1}}, {tuple, {0, 1}}, {tuple, {1}}});
    596   ExpectHasBufferAliases(inner_tuple, /*index=*/{},
    597                          {{inner_tuple, {}}, {tuple, {0}}});
    598   ExpectHasBufferAliases(tuple, /*index=*/{}, {{tuple, {}}});
    599 }
    600 
    601 class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest {
    602  protected:
    603   // Builds a computation, runs instruction fusion HloPass, runs points-to
    604   // analysis, then checks for expected results (see unit test cases for
    605   // example computation graphs).
    606   void Run(const bool add_additional_gte0_user) {
    607     Shape input_shape = ShapeUtil::MakeShape(F32, {8});
    608     Shape update_shape = ShapeUtil::MakeShape(F32, {3});
    609     Shape starts_shape = ShapeUtil::MakeShape(S32, {1});
    610     Shape tuple_shape =
    611         ShapeUtil::MakeTupleShape({input_shape, update_shape, starts_shape});
    612 
    613     auto builder = HloComputation::Builder(TestName());
    614     // Create tuple-shaped parameter.
    615     auto tuple_param0 = builder.AddInstruction(
    616         HloInstruction::CreateParameter(0, tuple_shape, "param0"));
    617     // Create 'tuple_element1' = GetTupleElement(tuple_param0, 1).
    618     auto tuple_element1 = builder.AddInstruction(
    619         HloInstruction::CreateGetTupleElement(update_shape, tuple_param0, 1));
    620     auto ones = builder.AddInstruction(HloInstruction::CreateConstant(
    621         Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f})));
    622     // Create 'update' = Add(GetTupleElement(tuple_param0, 1), ones)
    623     auto update = builder.AddInstruction(HloInstruction::CreateBinary(
    624         update_shape, HloOpcode::kAdd, tuple_element1, ones));
    625     // Create 'input' = GetTupleElement(tuple_param0, 0).
    626     auto input = builder.AddInstruction(
    627         HloInstruction::CreateGetTupleElement(input_shape, tuple_param0, 0));
    628 
    629     if (add_additional_gte0_user) {
    630       // Create 'slice' as an additional user of 'input'.
    631       auto slice = builder.AddInstruction(
    632           HloInstruction::CreateSlice(update_shape, input, {0}, {3}, {1}));
    633       // Modify 'update' to take 'slice' output.
    634       update = builder.AddInstruction(HloInstruction::CreateBinary(
    635           update_shape, HloOpcode::kAdd, update, slice));
    636     }
    637 
    638     // Create slice 'starts' = GetTupleElement(tuple_param0, 2).
    639     auto starts = builder.AddInstruction(
    640         HloInstruction::CreateGetTupleElement(starts_shape, tuple_param0, 2));
    641     // Update 'input' with 'update' at dynamic 'starts' indices.
    642     builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
    643         input_shape, input, update, starts));
    644 
    645     // Build computation and add it to module as entry computation.
    646     BuildModule(builder.Build());
    647     // Run instruction fusion HloPass.
    648     EXPECT_TRUE(InstructionFusion(InstructionFusion::IsExpensive)
    649                     .Run(module_.get())
    650                     .ValueOrDie());
    651     // Get computation root instruction (should be a kFusion).
    652     auto* fusion = module_->entry_computation()->root_instruction();
    653     EXPECT_THAT(fusion, op::Fusion(tuple_param0));
    654     // Run points-to analysis (should include fused instructions from 'fusion').
    655     RunAnalysis();
    656 
    657     // Check points-to set of fusion parameter associated with 'tuple_param0'.
    658     auto* fusion_param = GetFusionParameterForOperand(fusion, tuple_param0);
    659     ExpectHasBuffers(
    660         points_to_analysis_->GetPointsToSet(fusion_param).element({}),
    661         {GetBuffer(fusion_param, {})});
    662     ExpectHasBuffers(
    663         points_to_analysis_->GetPointsToSet(fusion_param).element({0}),
    664         {GetBuffer(fusion_param, {0})});
    665     ExpectHasBuffers(
    666         points_to_analysis_->GetPointsToSet(fusion_param).element({1}),
    667         {GetBuffer(fusion_param, {1})});
    668     ExpectHasBuffers(
    669         points_to_analysis_->GetPointsToSet(fusion_param).element({2}),
    670         {GetBuffer(fusion_param, {2})});
    671 
    672     // Check that Gte at tuple_index = 0 points-to fusion_param({0})
    673     auto fused_gte0 = GetUniqueFusionParameterUserAt(fusion_param, 0);
    674     ExpectHasBuffers(
    675         points_to_analysis_->GetPointsToSet(fused_gte0).element({}),
    676         {GetBuffer(fusion_param, {0})});
    677     // Check that Gte at tuple_index = 1 points-to fusion_param({1})
    678     auto fused_gte1 = GetUniqueFusionParameterUserAt(fusion_param, 1);
    679     ExpectHasBuffers(
    680         points_to_analysis_->GetPointsToSet(fused_gte1).element({}),
    681         {GetBuffer(fusion_param, {1})});
    682     // Check that Gte at tuple_index = 2 points-to fusion_param({2})
    683     auto fused_gte2 = GetUniqueFusionParameterUserAt(fusion_param, 2);
    684     ExpectHasBuffers(
    685         points_to_analysis_->GetPointsToSet(fused_gte2).element({}),
    686         {GetBuffer(fusion_param, {2})});
    687 
    688     // Check buffer aliases of 'fusion_param' at shape index {0}.
    689     ExpectHasBufferAliases(fusion_param, /*index=*/{0},
    690                            {{fusion_param, {0}}, {fused_gte0, {}}});
    691     // Check buffer aliases of 'fusion_param' at shape index {1}.
    692     ExpectHasBufferAliases(fusion_param, /*index=*/{1},
    693                            {{fusion_param, {1}}, {fused_gte1, {}}});
    694     // Check buffer aliases of 'fusion_param' at shape index {2}.
    695     ExpectHasBufferAliases(fusion_param, /*index=*/{2},
    696                            {{fusion_param, {2}}, {fused_gte2, {}}});
    697 
    698     // Check number of users of 'fusion_param' aliases at shape index {0}.
    699     ExpectNumUsersOfAliases(fusion_param, {0},
    700                             add_additional_gte0_user ? 2 : 1);
    701   }
    702 
    703   // Returns fusion parameter (from 'fusion.fused_instructions') corresponding
    704   // to fusion 'operand'.
    705   HloInstruction* GetFusionParameterForOperand(HloInstruction* fusion,
    706                                                HloInstruction* operand) {
    707     auto it = std::find_if(
    708         fusion->fused_instructions().begin(),
    709         fusion->fused_instructions().end(), [=](const HloInstruction* fused) {
    710           return fused->opcode() == HloOpcode::kParameter &&
    711                  fusion->operand(fused->parameter_number()) == operand;
    712         });
    713     CHECK(it != fusion->fused_instructions().end());
    714     return *it;
    715   }
    716 
    717   // Returns all users of 'fusion_paran' at 'tuple_index'.
    718   std::vector<HloInstruction*> GetFusionParameterUsersAt(
    719       HloInstruction* fusion_param, int64 tuple_index) {
    720     CHECK(ShapeUtil::IsTuple(fusion_param->shape()));
    721     std::vector<HloInstruction*> users_at_tuple_index;
    722     for (auto user : fusion_param->users()) {
    723       CHECK_EQ(HloOpcode::kGetTupleElement, user->opcode());
    724       if (user->tuple_index() == tuple_index) {
    725         users_at_tuple_index.push_back(user);
    726       }
    727     }
    728     return users_at_tuple_index;
    729   }
    730 
    731   // Returns the unique user of 'fusion_param' at 'tuple_index'.
    732   HloInstruction* GetUniqueFusionParameterUserAt(HloInstruction* fusion_param,
    733                                                  int64 tuple_index) {
    734     std::vector<HloInstruction*> users =
    735         GetFusionParameterUsersAt(fusion_param, tuple_index);
    736     CHECK_EQ(1, users.size());
    737     return users[0];
    738   }
    739 
    740   // Checks that the count of all users of all aliases of 'instruction' at
    741   // 'index' match 'expected_num_users'.
    742   void ExpectNumUsersOfAliases(const HloInstruction* instruction,
    743                                const ShapeIndex& index,
    744                                const int64 expected_num_users) {
    745     const auto* buffer = GetBuffer(instruction, index);
    746     int64 num_users = 0;
    747     for (const auto& alias : points_to_analysis_->GetBufferAliases(*buffer)) {
    748       for (auto user : alias.instruction()->users()) {
    749         if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
    750           // Gte instructions only access the top-level buffer of their operand.
    751           continue;
    752         }
    753         ++num_users;
    754       }
    755     }
    756     EXPECT_EQ(expected_num_users, num_users);
    757   }
    758 };
    759 
    760 // Tests the points-to set of tuple-shaped fusion parameter 0 and all GTE users.
    761 // Tests the alias set of tuple-shaped fusion parameter 0 at all shape indices.
    762 // Tests that there is a single user of the aliases of tuple-shaped fusion
    763 // parameter 0 at shape index {0}.
    764 //
    765 //             Param0    Const
    766 //                 \      /
    767 //                  Fusion
    768 //                 /      \
    769 //        FusionParam0   FusionParam1
    770 //        /     |    \       |
    771 //     Gte(0) Gte(2) Gte(1)  /
    772 //        \     |      \    /
    773 //         \    |       Add
    774 //          \   |        /
    775 //           \0 |2      /1
    776 //          DynamicUpdateSlice  // fused root.
    777 //
    778 TEST_F(FusionPointsToAnalysisTest, FusionParam0OneUser) {
    779   Run(/*add_additional_gte0_user=*/false);
    780 }
    781 
    782 // Tests the points-to set of tuple-shaped fusion parameter 0 and all GTE users.
    783 // Tests the alias set of tuple-shaped fusion parameter 0 at all shape indices.
    784 // Tests that there are two users of the aliases of tuple-shaped fusion
    785 // parameter 0 at shape index {0}.
    786 //
    787 //             Param0    Const
    788 //                 \      /
    789 //                  Fusion
    790 //                 /      \
    791 //        FusionParam0   FusionParam1
    792 //        /     |    \       |
    793 //     Gte(2) Gte(0) Gte(1)  /
    794 //        \     |      \    /
    795 //         \    |\      Add
    796 //          \   | \      /
    797 //           |  | Slice /
    798 //           |  |   \  /
    799 //           |  |   Add
    800 //           |  |    |
    801 //           |2 |0   |1
    802 //          DynamicUpdateSlice  // fused root.
    803 //
    804 TEST_F(FusionPointsToAnalysisTest, FusionParam0TwoUsers) {
    805   Run(/*add_additional_gte0_user=*/true);
    806 }
    807 
    808 }  // namespace
    809 }  // namespace xla
    810