Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
     17 
     18 #include <memory>
     19 #include <set>
     20 #include <string>
     21 #include <utility>
     22 #include <vector>
     23 
     24 #include "tensorflow/compiler/xla/literal_util.h"
     25 #include "tensorflow/compiler/xla/ptr_util.h"
     26 #include "tensorflow/compiler/xla/service/call_graph.h"
     27 #include "tensorflow/compiler/xla/service/computation_tracker.h"
     28 #include "tensorflow/compiler/xla/service/copy_insertion.h"
     29 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
     30 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
     31 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     32 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     33 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     34 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
     35 #include "tensorflow/compiler/xla/service/hlo_scheduling.h"
     36 #include "tensorflow/compiler/xla/shape_util.h"
     37 #include "tensorflow/compiler/xla/test.h"
     38 #include "tensorflow/compiler/xla/test_helpers.h"
     39 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     40 #include "tensorflow/compiler/xla/types.h"
     41 #include "tensorflow/compiler/xla/xla_data.pb.h"
     42 #include "tensorflow/core/platform/macros.h"
     43 
     44 namespace xla {
     45 
     46 namespace {
     47 
     48 // DFS visitor that collects the instructions referenced by a computation
     49 // without descending into nested computations, i.e., only from the operands.
     50 class InstructionListVisitor : public DfsHloVisitorWithDefault {
     51  public:
     52   explicit InstructionListVisitor(const HloInstruction* root) : root_(root) {}
     53 
     54   Status DefaultAction(HloInstruction* hlo) override {
     55     // For each instruction, just push it on the list after walking the
     56     // operands.
     57     instructions_.push_back(hlo);
     58     VLOG(0) << "List instruction " << hlo->ToString();
     59     return Status::OK();
     60   }
     61 
     62   std::vector<const HloInstruction*> GetInstructions() { return instructions_; }
     63 
     64  private:
     65   // The instruction root of the computation.
     66   const HloInstruction* root_;
     67 
     68   // The full set of instructions found (may be duplicates, e.g., kParameter).
     69   std::vector<const HloInstruction*> instructions_;
     70 
     71   TF_DISALLOW_COPY_AND_ASSIGN(InstructionListVisitor);
     72 };
     73 
     74 const std::vector<const HloInstruction*> GetInstructions(HloInstruction* root) {
     75   InstructionListVisitor main_list(root);
     76   TF_CHECK_OK(root->Accept(&main_list));
     77   return main_list.GetInstructions();
     78 }
     79 
     80 class BufferAssignmentTest : public HloTestBase {
     81  protected:
     82   BufferAssignmentTest() : computation_tracker_() {}
     83   ~BufferAssignmentTest() override {}
     84 
     85   std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
     86                                                         int64 alignment = 1) {
     87     return BufferAssigner::Run(
     88                module, xla::MakeUnique<DependencyHloOrdering>(module),
     89                backend().compiler()->BufferSizeBytesFunction(),
     90                [alignment](LogicalBuffer::Color) { return alignment; })
     91         .ConsumeValueOrDie();
     92   }
     93 
     94   std::unique_ptr<BufferAssignment> RunColoredBufferAssignment(
     95       HloModule* module, BufferLiveness::Colorer colorer, int64 alignment = 1) {
     96     return BufferAssigner::Run(
     97                module, xla::MakeUnique<DependencyHloOrdering>(module),
     98                backend().compiler()->BufferSizeBytesFunction(),
     99                [alignment](LogicalBuffer::Color) { return alignment; }, false,
    100                std::move(colorer))
    101         .ConsumeValueOrDie();
    102   }
    103 
    104   // Builds an x+1.0 computation to use in a Map.
    105   std::unique_ptr<HloComputation> BuildMapComputationPlus1(const string& name) {
    106     auto builder = HloComputation::Builder(name);
    107     auto param =
    108         builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
    109     auto value = builder.AddInstruction(
    110         HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    111     builder.AddInstruction(
    112         HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value));
    113     return builder.Build();
    114   }
    115 
    116   // Builds a simple compare-to-limit (x < 4) computation for a While.
    117   //
    118   // condition:
    119   //   const4[s32] -----------------------------------\
    120   //                                                   \
    121   //   param[(s32,f32[4])] --- get-tuple-element[0] --- less-than
    122   //
    123   std::unique_ptr<HloComputation> BuildWhileConditionComputation(
    124       const string& name) {
    125     auto builder = HloComputation::Builder(name);
    126     auto const4 = builder.AddInstruction(
    127         HloInstruction::CreateConstant(Literal::CreateR0<int>(4)));
    128     auto param = builder.AddInstruction(
    129         HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
    130     auto index = builder.AddInstruction(
    131         HloInstruction::CreateGetTupleElement(const4->shape(), param, 0));
    132     builder.AddInstruction(
    133         HloInstruction::CreateBinary(r0f32_, HloOpcode::kLt, index, const4));
    134     return builder.Build();
    135   }
    136 
    137   // Builds a simple body computation for a While.
    138   //
    139   // body:
    140   //   constv[f32[4]] --------------------------------------\
    141   //                                                         \
    142   //                           /--- get-tuple-elementv[1] --- addv ---\
    143   //   param[(s32,f32[4])] ---|                                    tuple
    144   //                           \--- get-tuple-elementc[0] --- addc ---/
    145   //                                                         /
    146   //   const1[s32] -----------------------------------------/
    147   //
    148   std::unique_ptr<HloComputation> BuildWhileBodyComputation(
    149       const string& name) {
    150     auto builder = HloComputation::Builder(name);
    151     auto const1 = builder.AddInstruction(
    152         HloInstruction::CreateConstant(Literal::CreateR0<int>(1)));
    153     auto constv = builder.AddInstruction(HloInstruction::CreateConstant(
    154         Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
    155     auto param = builder.AddInstruction(
    156         HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
    157     auto indexc = builder.AddInstruction(
    158         HloInstruction::CreateGetTupleElement(const1->shape(), param, 0));
    159     auto addc = builder.AddInstruction(HloInstruction::CreateBinary(
    160         indexc->shape(), HloOpcode::kAdd, indexc, const1));
    161     auto indexv = builder.AddInstruction(
    162         HloInstruction::CreateGetTupleElement(constv->shape(), param, 1));
    163     auto addv = builder.AddInstruction(HloInstruction::CreateBinary(
    164         constv->shape(), HloOpcode::kAdd, indexv, constv));
    165     builder.AddInstruction(HloInstruction::CreateTuple({addc, addv}));
    166     return builder.Build();
    167   }
    168 
    169   std::unique_ptr<HloComputation> BuildR0F32UnaryOpComputation(
    170       HloOpcode opcode, const string& name) {
    171     auto builder = HloComputation::Builder(name);
    172     auto param =
    173         builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
    174     builder.AddInstruction(HloInstruction::CreateUnary(r0f32_, opcode, param));
    175     return builder.Build();
    176   }
    177 
    178   // Verifies that the given instruction hlo has a valid input buffer assigned,
    179   // i.e., the parameter number matches the op's.
    180   const BufferAllocation& GetAssignedInputAllocation(
    181       const BufferAssignment& buffers, HloInstruction* hlo) {
    182     LOG(INFO) << "Checking input: " << hlo->ToString();
    183     const BufferAllocation& buffer =
    184         *buffers.GetUniqueTopLevelSlice(hlo).ConsumeValueOrDie().allocation();
    185     EXPECT_EQ(hlo->parameter_number(), buffer.parameter_number());
    186     return buffer;
    187   }
    188 
    189   // Verifies that the given instruction hlo has a valid output buffer
    190   // assigned, and returns it.
    191   const BufferAllocation& GetAssignedOutputAllocation(
    192       const BufferAssignment& buffers, HloInstruction* hlo) {
    193     LOG(INFO) << "Checking output: " << hlo->ToString();
    194     const BufferAllocation& buffer = GetTopLevelAllocation(buffers, hlo);
    195     return buffer;
    196   }
    197 
    198   // Returns the allocation for the given instruction.
    199   const BufferAllocation& GetAllocation(const BufferAssignment& buffers,
    200                                         const HloInstruction* hlo,
    201                                         const ShapeIndex& index) {
    202     return *buffers.GetUniqueSlice(hlo, index).ConsumeValueOrDie().allocation();
    203   }
    204   const BufferAllocation& GetTopLevelAllocation(const BufferAssignment& buffers,
    205                                                 const HloInstruction* hlo) {
    206     return *buffers.GetUniqueTopLevelSlice(hlo)
    207                 .ConsumeValueOrDie()
    208                 .allocation();
    209   }
    210 
    211   // Verifies that all instructions in the given instruction list except
    212   // kConstant have assigned buffers, and returns their total size. If min_index
    213   // and max_index are not nullptr, the minimum and maximum buffer indices in
    214   // the assignment are written into them.
    215   int64 ValidateBuffers(const std::vector<const HloInstruction*>& instructions,
    216                         const BufferAssignment& buffers) {
    217     // Verifies all instructions have buffers, and gets the index ranges.
    218     for (const HloInstruction* hlo : instructions) {
    219       if (!buffers.HasTopLevelAllocation(hlo)) {
    220         // If `hlo` has no assigned buffer, it is either a constant or a nested
    221         // parameter.
    222         EXPECT_TRUE(HloOpcode::kConstant == hlo->opcode() ||
    223                     HloOpcode::kParameter == hlo->opcode());
    224         continue;
    225       }
    226     }
    227 
    228     // Gets the total size of all buffers assigned.
    229     int64 total_size = 0;
    230     for (auto& allocation : buffers.Allocations()) {
    231       total_size += allocation.size();
    232     }
    233     return total_size;
    234   }
    235 
    236   // Computation tracker for nested computations.
    237   ComputationTracker computation_tracker_;
    238 
    239   // Shapes for use in the examples.
    240   Shape s32_ = ShapeUtil::MakeShape(xla::S32, {});
    241   Shape r0f32_ = ShapeUtil::MakeShape(xla::F32, {});
    242   Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4});
    243   Shape f32vec10_ = ShapeUtil::MakeShape(F32, {10});
    244   Shape f32vec100_ = ShapeUtil::MakeShape(F32, {100});
    245   Shape f32a100x10_ = ShapeUtil::MakeShape(F32, {100, 10});
    246   Shape t_s32_f32v4_ = ShapeUtil::MakeTupleShape({s32_, f32vec4_});
    247   Shape t_s32_f32v10_ = ShapeUtil::MakeTupleShape({s32_, f32vec10_});
    248 };
    249 
    250 // Returns true if the buffers assigned to instructions in "a" are distinct
    251 // from the buffers assigned to those in "b" (ie, intersection is empty).
    252 static bool BuffersDistinct(const std::vector<const HloInstruction*>& a,
    253                             const std::vector<const HloInstruction*>& b,
    254                             const BufferAssignment& assignment) {
    255   std::set<BufferAllocation::Slice> a_slices;
    256   for (const HloInstruction* instruction : a) {
    257     if (assignment.HasTopLevelAllocation(instruction)) {
    258       a_slices.insert(
    259           assignment.GetUniqueTopLevelSlice(instruction).ConsumeValueOrDie());
    260     }
    261   }
    262 
    263   for (const HloInstruction* instruction : b) {
    264     if (assignment.HasTopLevelAllocation(instruction)) {
    265       if (a_slices.count(assignment.GetUniqueTopLevelSlice(instruction)
    266                              .ConsumeValueOrDie())) {
    267         return false;
    268       }
    269     }
    270   }
    271   return true;
    272 }
    273 
    274 // Tests a computation consisting of a single scalar constant node.
    275 TEST_F(BufferAssignmentTest, ScalarConstant) {
    276   auto builder = HloComputation::Builder(TestName());
    277   auto const0 = builder.AddInstruction(
    278       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    279   auto module = CreateNewModule();
    280   module->AddEntryComputation(builder.Build());
    281 
    282   auto buffers = RunBufferAssignment(module.get());
    283   // Check that the constant does not have a buffer assigned.
    284   EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
    285 }
    286 
    287 TEST_F(BufferAssignmentTest, BufferForConst) {
    288   // Addition of two vector constants: checks that internal constant nodes have
    289   // no buffers assigned, and their consumer has a buffer.
    290   auto builder = HloComputation::Builder(TestName());
    291   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
    292       Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
    293   auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
    294       Literal::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
    295   auto add = builder.AddInstruction(
    296       HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1));
    297   auto module = CreateNewModule();
    298   module->AddEntryComputation(builder.Build());
    299 
    300   auto buffers = RunBufferAssignment(module.get());
    301   // The two constant nodes have no buffers assigned.
    302   EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
    303   EXPECT_FALSE(buffers->HasTopLevelAllocation(const1));
    304   // The add node has an output buffer.
    305   GetAssignedOutputAllocation(*buffers, add);
    306 }
    307 
    308 TEST_F(BufferAssignmentTest, HasAllocationAt) {
    309   // Create a tuple with non-const and const elements and check that
    310   // HasAllocationAt works correctly.
    311   auto builder = HloComputation::Builder(TestName());
    312   auto param0 = builder.AddInstruction(
    313       HloInstruction::CreateParameter(0, f32vec100_, "param0"));
    314   auto constant = builder.AddInstruction(
    315       HloInstruction::CreateConstant(Literal::CreateR0<int>(1)));
    316   auto negate = builder.AddInstruction(
    317       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
    318   auto tuple = builder.AddInstruction(
    319       HloInstruction::CreateTuple({negate, param0, constant}));
    320   auto module = CreateNewModule();
    321   module->AddEntryComputation(builder.Build());
    322 
    323   auto buffers = RunBufferAssignment(module.get());
    324   // Make sure that HasAllocationAt() agrees with what HasTopLevelAllocation()
    325   // reports for the instruction directly.
    326   EXPECT_EQ(buffers->HasTopLevelAllocation(tuple),
    327             buffers->HasAllocationAt(tuple, /*index=*/{}));
    328   EXPECT_EQ(buffers->HasTopLevelAllocation(negate),
    329             buffers->HasAllocationAt(tuple, /*index=*/{0}));
    330   EXPECT_EQ(buffers->HasTopLevelAllocation(param0),
    331             buffers->HasAllocationAt(tuple, /*index=*/{1}));
    332   EXPECT_EQ(buffers->HasTopLevelAllocation(constant),
    333             buffers->HasAllocationAt(tuple, /*index=*/{2}));
    334 }
    335 
    336 TEST_F(BufferAssignmentTest, BufferForOutputConst) {
    337   // This computation copies a constant to output.
    338   auto builder = HloComputation::Builder(TestName());
    339   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
    340       Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
    341   auto copy = builder.AddInstruction(
    342       HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0));
    343   auto module = CreateNewModule();
    344   module->AddEntryComputation(builder.Build());
    345 
    346   auto buffers = RunBufferAssignment(module.get());
    347   // The copy node now has an output buffer.
    348   GetAssignedOutputAllocation(*buffers, copy);
    349 }
    350 
    351 TEST_F(BufferAssignmentTest, Basic) {
    352   // paramscalar ------- (mul) -- (add) -- (sub)
    353   //                     /        /        /
    354   // param0[100] -------/        /        /
    355   //                            /        /
    356   // param1[100] --------------/--------/
    357   auto builder = HloComputation::Builder(TestName());
    358   auto paramscalar =
    359       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
    360   auto param0 = builder.AddInstruction(
    361       HloInstruction::CreateParameter(1, f32vec100_, ""));
    362   auto param1 = builder.AddInstruction(
    363       HloInstruction::CreateParameter(2, f32vec100_, ""));
    364   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
    365       f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
    366   auto add = builder.AddInstruction(
    367       HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
    368   auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
    369       f32vec100_, HloOpcode::kSubtract, add, param1));
    370   auto module = CreateNewModule();
    371   module->AddEntryComputation(builder.Build());
    372 
    373   auto buffers = RunBufferAssignment(module.get());
    374 
    375   // Distinct input buffers were assigned for parameters.
    376   BufferAllocation paramscalar_buffer =
    377       GetAssignedInputAllocation(*buffers, paramscalar);
    378   BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
    379   BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
    380   EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
    381   EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
    382   EXPECT_NE(param0_buffer.index(), param1_buffer.index());
    383 
    384   // The mul node has a valid buffer assigned, doesn't share with input.
    385   const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
    386   EXPECT_NE(mul_buffer.index(), param0_buffer.index());
    387 
    388   // The add node can reuse the mul node's buffer.
    389   const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
    390   EXPECT_EQ(add_buffer.index(), mul_buffer.index());
    391 
    392   // The sub node has a valid output buffer assigned.
    393   GetAssignedOutputAllocation(*buffers, sub);
    394 }
    395 
    396 TEST_F(BufferAssignmentTest, BasicUniquelyColored) {
    397   // paramscalar ------- (mul) -- (add) -- (sub)
    398   //                     /        /        /
    399   // param0[100] -------/        /        /
    400   //                            /        /
    401   // param1[100] --------------/--------/
    402   // The output of each op is colored with a different color, so we can not
    403   // share anything.
    404   auto builder = HloComputation::Builder(TestName());
    405   auto paramscalar =
    406       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
    407   auto param0 = builder.AddInstruction(
    408       HloInstruction::CreateParameter(1, f32vec100_, ""));
    409   auto param1 = builder.AddInstruction(
    410       HloInstruction::CreateParameter(2, f32vec100_, ""));
    411   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
    412       f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
    413   auto add = builder.AddInstruction(
    414       HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
    415   auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
    416       f32vec100_, HloOpcode::kSubtract, add, param1));
    417   auto module = CreateNewModule();
    418   module->AddEntryComputation(builder.Build());
    419 
    420   auto colorer = [](const BufferLiveness& buffer_liveness) {
    421     int color = 0;
    422 
    423     for (LogicalBuffer::Id id = 0;
    424          id < buffer_liveness.points_to_analysis().num_logical_buffers();
    425          id++) {
    426       auto& buffer = buffer_liveness.points_to_analysis().logical_buffer(id);
    427       buffer.set_color(LogicalBuffer::Color(color++));
    428     }
    429     return Status::OK();
    430   };
    431 
    432   auto buffers = RunColoredBufferAssignment(module.get(), colorer);
    433 
    434   // Distinct input buffers were assigned for parameters.
    435   BufferAllocation paramscalar_buffer =
    436       GetAssignedInputAllocation(*buffers, paramscalar);
    437   BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
    438   BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
    439   EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
    440   EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
    441   EXPECT_NE(param0_buffer.index(), param1_buffer.index());
    442 
    443   // The mul node has a valid buffer assigned, doesn't share with input.
    444   const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
    445   EXPECT_NE(mul_buffer.index(), param0_buffer.index());
    446 
    447   // The add node can not reuse the mul node's buffer due to coloring.
    448   const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
    449   EXPECT_NE(add_buffer.index(), mul_buffer.index());
    450 
    451   // The sub node has a valid output buffer assigned.
    452   GetAssignedOutputAllocation(*buffers, sub);
    453 }
    454 
    455 TEST_F(BufferAssignmentTest, BasicPartiallyColored) {
    456   // paramscalar ------- (mul) -- (add) -- (sub)
    457   //                     /        /        /
    458   // param0[100] -------/        /        /
    459   //                            /        /
    460   // param1[100] --------------/--------/
    461   // The output of the mul and the add have the color 1, and the other buffers
    462   // have the color 0, which allows the mul and add to share buffers.
    463   auto builder = HloComputation::Builder(TestName());
    464   auto paramscalar =
    465       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
    466   auto param0 = builder.AddInstruction(
    467       HloInstruction::CreateParameter(1, f32vec100_, ""));
    468   auto param1 = builder.AddInstruction(
    469       HloInstruction::CreateParameter(2, f32vec100_, ""));
    470   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
    471       f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
    472   auto add = builder.AddInstruction(
    473       HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
    474   auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
    475       f32vec100_, HloOpcode::kSubtract, add, param1));
    476   auto module = CreateNewModule();
    477   module->AddEntryComputation(builder.Build());
    478 
    479   auto colorer = [](const BufferLiveness& buffer_liveness) {
    480     for (LogicalBuffer::Id id = 0;
    481          id < buffer_liveness.points_to_analysis().num_logical_buffers();
    482          id++) {
    483       auto& buffer = buffer_liveness.points_to_analysis().logical_buffer(id);
    484       const auto& aliases =
    485           buffer_liveness.points_to_analysis().GetBufferAliases(buffer);
    486       for (const auto& alias : aliases) {
    487         if (alias.instruction()->opcode() == HloOpcode::kAdd ||
    488             alias.instruction()->opcode() == HloOpcode::kMultiply) {
    489           buffer.set_color(LogicalBuffer::Color(1));
    490         }
    491       }
    492       if (!buffer.has_color()) {
    493         buffer.set_color(LogicalBuffer::Color(0));
    494       }
    495     }
    496     return Status::OK();
    497   };
    498 
    499   auto buffers = RunColoredBufferAssignment(module.get(), colorer);
    500 
    501   // Distinct input buffers were assigned for parameters.
    502   BufferAllocation paramscalar_buffer =
    503       GetAssignedInputAllocation(*buffers, paramscalar);
    504   BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
    505   BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
    506   EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
    507   EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
    508   EXPECT_NE(param0_buffer.index(), param1_buffer.index());
    509 
    510   // The mul node has a valid buffer assigned, doesn't share with input.
    511   const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
    512   EXPECT_NE(mul_buffer.index(), param0_buffer.index());
    513 
    514   // The add node can reuse the mul node's buffer.
    515   const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
    516   EXPECT_EQ(add_buffer.index(), mul_buffer.index());
    517 
    518   // The sub node has a valid output buffer assigned.
    519   GetAssignedOutputAllocation(*buffers, sub);
    520 }
    521 
    522 TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
    523   // This is similar to the Basic test, with the difference that (sub) is
    524   // another user of (mul)'s result, so (mul)'s buffer cannot be reused for
    525   // (add)'s output.
    526   //
    527   // paramscalar -------\     /-----------\
    528   //                     \   /             \
    529   // param0[100] ------- (mul) -- (add) -- (sub)
    530   //                              /
    531   // param1[100] ----------------/
    532   //
    533   auto builder = HloComputation::Builder(TestName());
    534   auto paramscalar =
    535       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
    536   auto param0 = builder.AddInstruction(
    537       HloInstruction::CreateParameter(1, f32vec100_, ""));
    538   auto param1 = builder.AddInstruction(
    539       HloInstruction::CreateParameter(2, f32vec100_, ""));
    540   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
    541       f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
    542   auto add = builder.AddInstruction(
    543       HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
    544   auto sub = builder.AddInstruction(
    545       HloInstruction::CreateBinary(f32vec100_, HloOpcode::kSubtract, add, mul));
    546   auto module = CreateNewModule();
    547   module->AddEntryComputation(builder.Build());
    548 
    549   auto buffers = RunBufferAssignment(module.get());
    550 
    551   // Input buffers were assigned for parameters.
    552   BufferAllocation paramscalar_buffer =
    553       GetAssignedInputAllocation(*buffers, paramscalar);
    554   BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
    555   BufferAllocation param1_index = GetAssignedInputAllocation(*buffers, param1);
    556   EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
    557   EXPECT_NE(paramscalar_buffer.index(), param1_index.index());
    558   EXPECT_NE(param0_buffer.index(), param1_index.index());
    559 
    560   // The mul node had a buffer allocated.
    561   const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
    562 
    563   // Now the add node can't reuse the mul node's buffer.
    564   const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
    565   EXPECT_NE(add_buffer.index(), mul_buffer.index());
    566 
    567   // Log size information for inspection.
    568   const std::vector<const HloInstruction*> level0 = GetInstructions(sub);
    569   int64 size0 = ValidateBuffers(level0, *buffers);
    570   LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
    571             << " for " << level0.size() << " instructions; "
    572             << "total buffer size " << size0;
    573 }
    574 
    575 TEST_F(BufferAssignmentTest, TrivialMap) {
    576   // This tests a trivial x+1 map as the only operation.
    577   //
    578   // param0[100x10] ---> (map x+1)
    579   //
    580   // Builds the map function.
    581   auto module = CreateNewModule();
    582   auto map_computation =
    583       module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1"));
    584   auto inner_last = map_computation->root_instruction();
    585 
    586   // Creates the main kernel and verifies instruction counts.
    587   auto builder = HloComputation::Builder(TestName());
    588   auto param0 = builder.AddInstruction(
    589       HloInstruction::CreateParameter(0, f32a100x10_, ""));
    590   auto map = builder.AddInstruction(
    591       HloInstruction::CreateMap(f32a100x10_, {param0}, map_computation));
    592   module->AddEntryComputation(builder.Build());
    593 
    594   const std::vector<const HloInstruction*> level0 = GetInstructions(map);
    595   EXPECT_EQ(2, level0.size()) << "Invalid main kernel size";
    596   const std::vector<const HloInstruction*> level1 = GetInstructions(inner_last);
    597   EXPECT_EQ(3, level1.size()) << "Invalid nested add+1 size";
    598 
    599   // Assigns buffers and fetches sizes.
    600   auto buffers = RunBufferAssignment(module.get());
    601   int64 size0 = ValidateBuffers(level0, *buffers);
    602   int64 size1 = ValidateBuffers(level1, *buffers);
    603 
    604   // Both algorithms assign the map's buffer before processing the embedded
    605   // computation, so we can verify that the buffers aren't shared between them
    606   // by checking:
    607   EXPECT_TRUE(BuffersDistinct(level0, level1, *buffers))
    608       << "Reuse between main kernel and embedded mapping.";
    609 
    610   // An input buffer was assigned for the parameter.
    611   BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
    612 
    613   // An output buffer was assigned for the map.
    614   BufferAllocation map_buffer = GetAssignedOutputAllocation(*buffers, map);
    615   EXPECT_NE(param0_buffer.index(), map_buffer.index());
    616 
    617   // The final computation node of the map is an add of an f32 param and a
    618   // constant.
    619   EXPECT_EQ(HloOpcode::kAdd, inner_last->opcode());
    620   const BufferAllocation& inner_add_buffer =
    621       GetTopLevelAllocation(*buffers, inner_last);
    622   EXPECT_NE(inner_add_buffer.index(), map_buffer.index());
    623 
    624   // Log size information for inspection.
    625   LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
    626             << " for " << level0.size() + level1.size() << " instructions; "
    627             << "total buffer size " << size0 + size1;
    628 }
    629 
    630 TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
    631   // Make sure that the input buffer of a reduce cannot be reused for its
    632   // output.  (Reuse is not safe in the general case, as it reshapes and some
    633   // out-of-order reductions could overwrite an element before a use.)
    634   //
    635   // param0[100] --- (exp1) --- (exp2) --- (reduce x+1) --- (exp3)
    636   auto module = CreateNewModule();
    637   auto reduce_computation =
    638       module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1"));
    639 
    640   auto builder = HloComputation::Builder(TestName());
    641   auto param0 = builder.AddInstruction(
    642       HloInstruction::CreateParameter(0, f32a100x10_, ""));
    643   auto exp1 = builder.AddInstruction(
    644       HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, param0));
    645   auto exp2 = builder.AddInstruction(
    646       HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, exp1));
    647   auto const0 = builder.AddInstruction(
    648       HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
    649   auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
    650       /*shape=*/f32vec10_,
    651       /*operand=*/exp2,
    652       /*init_value=*/const0,
    653       /*dimensions_to_reduce=*/{0}, reduce_computation));
    654   auto exp3 = builder.AddInstruction(
    655       HloInstruction::CreateUnary(f32vec10_, HloOpcode::kExp, reduce));
    656 
    657   module->AddEntryComputation(builder.Build());
    658 
    659   auto buffers = RunBufferAssignment(module.get());
    660   const std::vector<const HloInstruction*> instrs = GetInstructions(exp3);
    661   ValidateBuffers(instrs, *buffers);
    662 
    663   const BufferAllocation& exp1_buffer = GetTopLevelAllocation(*buffers, exp1);
    664   const BufferAllocation& exp2_buffer = GetTopLevelAllocation(*buffers, exp2);
    665   const BufferAllocation& reduce_buffer =
    666       GetTopLevelAllocation(*buffers, reduce);
    667 
    668   // The buffer of exp1 is trivially reusable for exp2 - this is just for sanity
    669   // checking.
    670   EXPECT_EQ(exp1_buffer.index(), exp2_buffer.index());
    671 
    672   // The buffer of exp2 cannot be used for reduce, even though it's the only
    673   // operand.
    674   EXPECT_NE(exp2_buffer.index(), reduce_buffer.index());
    675 }
    676 
    677 TEST_F(BufferAssignmentTest, ExampleWhile) {
    678   // This tests a While loop example from the ir_semantics document.
    679   //
    680   // condition (s32,f32[4]) -> bool -- see BuildWhileConditionComputation.
    681   // body: (s32,f32[4]) -> (s32,f32[4]) -- see BuildWhileBodyComputation.
    682   //
    683   // const3[s32] -------\
    684   // const4[f32[4]] --- tuple --- while[condition, body]
    685   //
    686   // Builds the nested condition and body.
    687   auto module = CreateNewModule();
    688   auto condition_computation =
    689       module->AddEmbeddedComputation(BuildWhileConditionComputation("if<4"));
    690   auto body_computation =
    691       module->AddEmbeddedComputation(BuildWhileBodyComputation("add-update"));
    692 
    693   // Creates the main kernel and verifies instruction counts.
    694   auto builder = HloComputation::Builder(TestName());
    695   auto const3 = builder.AddInstruction(
    696       HloInstruction::CreateConstant(Literal::CreateR0<int>(0)));
    697   auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
    698       Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
    699   auto tuple =
    700       builder.AddInstruction(HloInstruction::CreateTuple({const3, const4}));
    701   auto while_op = builder.AddInstruction(HloInstruction::CreateWhile(
    702       t_s32_f32v4_, condition_computation, body_computation, tuple));
    703   module->AddEntryComputation(builder.Build());
    704 
    705   const std::vector<const HloInstruction*> level0 = GetInstructions(while_op);
    706   EXPECT_EQ(4, level0.size()) << "Invalid while kernel size";
    707   const std::vector<const HloInstruction*> levelc =
    708       GetInstructions(condition_computation->root_instruction());
    709   EXPECT_EQ(4, levelc.size()) << "Invalid nested condition size";
    710   const std::vector<const HloInstruction*> levelb =
    711       GetInstructions(body_computation->root_instruction());
    712   EXPECT_EQ(8, levelb.size()) << "Invalid nested body size";
    713 
    714   // Assigns buffers and fetches sizes.
    715   auto buffers = RunBufferAssignment(module.get());
    716   int64 size0 = ValidateBuffers(level0, *buffers);
    717   int64 sizec = ValidateBuffers(levelc, *buffers);
    718   int64 sizeb = ValidateBuffers(levelb, *buffers);
    719 
    720   // BufferAssignment will assign a single allocation for the following
    721   // instructions: while, while.cond.param, while.body.param, while.body.result.
    722   EXPECT_FALSE(BuffersDistinct(level0, levelc, *buffers))
    723       << "Should be reuse between main kernel and embedded condition.";
    724   EXPECT_FALSE(BuffersDistinct(levelb, levelc, *buffers))
    725       << "Should be reuse between embedded condition and body.";
    726   // Expect buffer reuse between main kernel and body computation.
    727   EXPECT_FALSE(BuffersDistinct(level0, levelb, *buffers))
    728       << "Should be reuse between main kernel and embedded body.";
    729 
    730   // The final computation node of the while body is a tuple of s32 and
    731   // f32[4] adds.
    732   HloInstruction* body_root = body_computation->root_instruction();
    733   EXPECT_EQ(HloOpcode::kTuple, body_root->opcode());
    734 
    735   // Check that buffer for each subshape of 'while_op' shares allocation with
    736   // corresponding buffer from while body computation at same index.
    737   ShapeUtil::ForEachSubshape(
    738       while_op->shape(),
    739       [this, &buffers, while_op, body_root](const Shape& /*subshape*/,
    740                                             const ShapeIndex& index) {
    741         auto while_op_allocation = GetAllocation(*buffers, while_op, index);
    742         auto body_root_allocation = GetAllocation(*buffers, body_root, index);
    743         EXPECT_EQ(while_op_allocation.index(), body_root_allocation.index());
    744       });
    745 
    746   // Log size information for inspection.
    747   LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
    748             << " for " << level0.size() + levelc.size() + levelb.size()
    749             << " instructions; total buffer size " << size0 + sizec + sizeb;
    750 }
    751 
    752 TEST_F(BufferAssignmentTest, ExampleConditional) {
    753   auto module = CreateNewModule();
    754   auto true_computation = module->AddEmbeddedComputation(
    755       BuildR0F32UnaryOpComputation(HloOpcode::kCeil, "Ceil"));
    756   auto false_computation = module->AddEmbeddedComputation(
    757       BuildR0F32UnaryOpComputation(HloOpcode::kFloor, "Floor"));
    758 
    759   auto builder = HloComputation::Builder(TestName());
    760   auto pred = builder.AddInstruction(
    761       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    762   auto const1 = builder.AddInstruction(
    763       HloInstruction::CreateConstant(Literal::CreateR0<float>(56.4f)));
    764   auto const2 = builder.AddInstruction(
    765       HloInstruction::CreateConstant(Literal::CreateR0<float>(12.4f)));
    766   auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
    767       r0f32_, pred, const1, true_computation, const2, false_computation));
    768   module->AddEntryComputation(builder.Build());
    769 
    770   const std::vector<const HloInstruction*> conditional_instrs =
    771       GetInstructions(conditional);
    772   const std::vector<const HloInstruction*> true_instrs =
    773       GetInstructions(true_computation->root_instruction());
    774   const std::vector<const HloInstruction*> false_instrs =
    775       GetInstructions(false_computation->root_instruction());
    776   EXPECT_EQ(4, conditional_instrs.size());
    777   EXPECT_EQ(2, true_instrs.size());
    778   EXPECT_EQ(2, false_instrs.size());
    779 
    780   auto buffers = RunBufferAssignment(module.get());
    781   ValidateBuffers(conditional_instrs, *buffers);
    782   ValidateBuffers(true_instrs, *buffers);
    783   ValidateBuffers(false_instrs, *buffers);
    784 
    785   EXPECT_FALSE(BuffersDistinct(conditional_instrs, true_instrs, *buffers))
    786       << "Should be reuse between conditional and true computation.";
    787   EXPECT_FALSE(BuffersDistinct(conditional_instrs, false_instrs, *buffers))
    788       << "Should be reuse between conditional and false computation.";
    789   EXPECT_FALSE(BuffersDistinct(true_instrs, false_instrs, *buffers))
    790       << "Should be reuse between true and false computations.";
    791 
    792   const BufferAllocation& conditional_buffer =
    793       GetTopLevelAllocation(*buffers, conditional);
    794   const BufferAllocation& true_buffer =
    795       GetTopLevelAllocation(*buffers, true_computation->root_instruction());
    796   const BufferAllocation& false_buffer =
    797       GetTopLevelAllocation(*buffers, false_computation->root_instruction());
    798   EXPECT_EQ(conditional_buffer.size(), true_buffer.size());
    799   EXPECT_EQ(conditional_buffer.size(), false_buffer.size());
    800 }
    801 
    802 TEST_F(BufferAssignmentTest, UnaryOpReuseChain) {
    803   // param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg)
    804   auto builder = HloComputation::Builder(TestName());
    805   auto param0 = builder.AddInstruction(
    806       HloInstruction::CreateParameter(0, f32vec100_, ""));
    807   auto exp1 = builder.AddInstruction(
    808       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, param0));
    809   auto tanh = builder.AddInstruction(
    810       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kTanh, exp1));
    811   auto exp2 = builder.AddInstruction(
    812       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, tanh));
    813   auto neg = builder.AddInstruction(
    814       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, exp2));
    815 
    816   auto module = CreateNewModule();
    817   module->AddEntryComputation(builder.Build());
    818   auto assignment = RunBufferAssignment(module.get());
    819 
    820   // tanh and exp2 can reuse exp1's buffer
    821   EXPECT_TRUE(assignment->HasTopLevelAllocation(exp1));
    822   auto& buffer_for_exp1 = GetTopLevelAllocation(*assignment, exp1);
    823   EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, tanh));
    824   EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, exp2));
    825   EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, neg));
    826 }
    827 
    828 TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) {
    829   // This computation is a chain of operations which decreases in buffer size
    830   // (via slice) then increases in size (via broadcast):
    831   //
    832   // param ---> (negate) ---> (slice) ---> (broadcast)
    833   //
    834   // The negate should share a buffer with broadcast.
    835   auto builder = HloComputation::Builder(TestName());
    836   auto param0 = builder.AddInstruction(
    837       HloInstruction::CreateParameter(0, f32vec100_, "param0"));
    838   auto negate = builder.AddInstruction(
    839       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
    840   auto slice = builder.AddInstruction(
    841       HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
    842   auto broadcast = builder.AddInstruction(
    843       HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
    844 
    845   auto module = CreateNewModule();
    846   module->AddEntryComputation(builder.Build());
    847   auto assignment = RunBufferAssignment(module.get());
    848 
    849   // negate and broadcast should share a buffer.
    850   EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
    851   auto& buffer_for_bcast = GetTopLevelAllocation(*assignment, broadcast);
    852   EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate));
    853 
    854   // Slice should have its own buffer.
    855   EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice));
    856 }
    857 
    858 TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) {
    859   // This computation is identical to that in ReuseNonOperandBuffer, but the
    860   // negate value is live until the end of the computation (due to it being an
    861   // operand of the output tuple) preventing reuse.
    862   //
    863   // param ---> (negate) ---> (slice) ---> (broadcast)-> (tuple)
    864   //                  \-----------------------------------/
    865   //
    866   // The negate should not share a buffer with broadcast.
    867   auto builder = HloComputation::Builder(TestName());
    868   auto param0 = builder.AddInstruction(
    869       HloInstruction::CreateParameter(0, f32vec100_, "param0"));
    870   auto negate = builder.AddInstruction(
    871       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
    872   auto slice = builder.AddInstruction(
    873       HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
    874   auto broadcast = builder.AddInstruction(
    875       HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
    876   builder.AddInstruction(HloInstruction::CreateTuple({negate, broadcast}));
    877 
    878   auto module = CreateNewModule();
    879   module->AddEntryComputation(builder.Build());
    880   auto assignment = RunBufferAssignment(module.get());
    881 
    882   // The instructions should not share buffers.
    883   EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
    884             GetTopLevelAllocation(*assignment, negate));
    885   EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
    886             GetTopLevelAllocation(*assignment, slice));
    887   EXPECT_NE(GetTopLevelAllocation(*assignment, negate),
    888             GetTopLevelAllocation(*assignment, slice));
    889 }
    890 
    891 TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) {
    892   // This computation is identical to that in ReuseNonOperandBuffer, but the
    893   // negate value is placed into a tuple which lives to the end of the
    894   // computation. This extends the live range of negate's buffer preventing
    895   // reuse due to buffer aliasing.
    896   //
    897   // param ---> (negate) ---> (tuple) -> (slice) ---> (broadcast)-> (tuple)
    898   //                              \-----------------------------------/
    899   //
    900   // The negate should not share a buffer with broadcast.
    901   auto builder = HloComputation::Builder(TestName());
    902   auto param0 = builder.AddInstruction(
    903       HloInstruction::CreateParameter(0, f32vec100_, "param0"));
    904   auto negate = builder.AddInstruction(
    905       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
    906   auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({negate}));
    907   auto tuple_element = builder.AddInstruction(
    908       HloInstruction::CreateGetTupleElement(f32vec100_, tuple, 0));
    909   auto slice = builder.AddInstruction(
    910       HloInstruction::CreateSlice(f32vec10_, tuple_element, {0}, {10}, {1}));
    911   auto broadcast = builder.AddInstruction(
    912       HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
    913   builder.AddInstruction(HloInstruction::CreateTuple({tuple, broadcast}));
    914 
    915   auto module = CreateNewModule();
    916   module->AddEntryComputation(builder.Build());
    917   auto assignment = RunBufferAssignment(module.get());
    918 
    919   // The instructions should not share buffers.
    920   EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
    921             GetTopLevelAllocation(*assignment, negate));
    922   EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
    923             GetTopLevelAllocation(*assignment, slice));
    924   EXPECT_NE(GetTopLevelAllocation(*assignment, negate),
    925             GetTopLevelAllocation(*assignment, slice));
    926 }
    927 
    928 TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) {
    929   // This computation is very similar to ReuseNonOperandBuffer except the
    930   // broadcast has a smaller output than the negate. This should block reuse of
    931   // negate's buffer by broadcast because the output buffer(s) of a computation
    932   // should be exactly sized for the value.
    933   //
    934   // param ---> (negate) ---> (slice) ---> (broadcast)
    935   //
    936   // Neither negate nor slice may share a buffer with broadcast.
    937   auto builder = HloComputation::Builder(TestName());
    938   auto param0 = builder.AddInstruction(
    939       HloInstruction::CreateParameter(0, f32vec100_, "param0"));
    940   // Negate output is 100 elements.
    941   auto negate = builder.AddInstruction(
    942       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
    943   // Slice output is 10 elements.
    944   auto slice = builder.AddInstruction(
    945       HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
    946   // Broadcast output is 40 elements.
    947   auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
    948       ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));
    949 
    950   auto module = CreateNewModule();
    951   module->AddEntryComputation(builder.Build());
    952   auto assignment = RunBufferAssignment(module.get());
    953 
    954   // The broadcast output buffer cannot be shared.
    955   EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
    956             GetTopLevelAllocation(*assignment, negate));
    957   EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
    958             GetTopLevelAllocation(*assignment, slice));
    959 }
    960 
    961 TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) {
    962   // This is identical to DoNotReuseOversizedOutputBuffer except the broadcast
    963   // output is exactly the same size as the negate (rather than being
    964   // smaller). This enables reuse of negate's buffer by the broadcast because
    965   // the output buffer will be sized exactly to its value.
    966   //
    967   // param ---> (negate) ---> (slice) ---> (broadcast)
    968   //
    969   // The negate should *not* share a buffer with broadcast.
    970   auto builder = HloComputation::Builder(TestName());
    971   auto param0 = builder.AddInstruction(
    972       HloInstruction::CreateParameter(0, f32vec100_, "param0"));
    973   // Negate output is 100 elements.
    974   auto negate = builder.AddInstruction(
    975       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
    976   auto slice = builder.AddInstruction(
    977       HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
    978   // Broadcast output is 40 elements.
    979   auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
    980       ShapeUtil::MakeShape(F32, {10, 10}), slice, {0}));
    981 
    982   auto module = CreateNewModule();
    983   module->AddEntryComputation(builder.Build());
    984   auto assignment = RunBufferAssignment(module.get());
    985 
    986   // negate and broadcast should share a buffer.
    987   EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
    988   auto& buffer_for_bcast = GetTopLevelAllocation(*assignment, broadcast);
    989   EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate));
    990 
    991   // Slice should have its own buffer.
    992   EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice));
    993 }
    994 
    995 TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) {
    996   // This computation is very similar to ReuseNonOperandBuffer except the
    997   // broadcast has a smaller output than the negate, and the broadcast is
    998   // contained in the computation output as a tuple element. This should block
    999   // reuse of the negate's buffer by the broadcast because the output buffer(s)
   1000   // of a computation should be exactly sized for the value. This includes those
   1001   // buffers aliased in the output (eg, contained as tuple elements).
   1002   //
   1003   // param ---> (negate) ---> (slice) ---> (broadcast) --> (tuple)
   1004   //
   1005   // Neither negate nor slice may share a buffer with broadcast.
   1006   auto builder = HloComputation::Builder(TestName());
   1007   auto param0 = builder.AddInstruction(
   1008       HloInstruction::CreateParameter(0, f32vec100_, "param0"));
   1009   // Negate output is 100 elements.
   1010   auto negate = builder.AddInstruction(
   1011       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
   1012   // Slice output is 10 elements.
   1013   auto slice = builder.AddInstruction(
   1014       HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
   1015   // Broadcast output is 40 elements.
   1016   auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
   1017       ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));
   1018   builder.AddInstruction(HloInstruction::CreateTuple({broadcast}));
   1019 
   1020   auto module = CreateNewModule();
   1021   module->AddEntryComputation(builder.Build());
   1022   auto assignment = RunBufferAssignment(module.get());
   1023 
   1024   // The broadcast output buffer cannot be shared.
   1025   EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
   1026             GetTopLevelAllocation(*assignment, negate));
   1027   EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
   1028             GetTopLevelAllocation(*assignment, slice));
   1029 }
   1030 
   1031 TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) {
   1032   // Verify that buffers for embedded computations are properly marked as
   1033   // thread-local and that embedded parameters are not marked as
   1034   // is_entry_computation_parameter.
   1035   auto module = CreateNewModule();
   1036   auto vec_shape = ShapeUtil::MakeShape(F32, {42});
   1037   auto scalar_shape = ShapeUtil::MakeShape(F32, {});
   1038 
   1039   // Create a scalar computation to use in a map.
   1040   auto map_builder = HloComputation::Builder(TestName() + "_map");
   1041   auto map_param = map_builder.AddInstruction(
   1042       HloInstruction::CreateParameter(0, scalar_shape, "map_param"));
   1043   auto map_root = map_builder.AddInstruction(
   1044       HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, map_param));
   1045   auto map_computation = module->AddEmbeddedComputation(map_builder.Build());
   1046 
   1047   // Create a vector computation to use in a kCall.
   1048   auto call_builder = HloComputation::Builder(TestName() + "_call");
   1049   auto call_param = call_builder.AddInstruction(
   1050       HloInstruction::CreateParameter(0, vec_shape, "vec_param"));
   1051   auto call_root = call_builder.AddInstruction(
   1052       HloInstruction::CreateUnary(vec_shape, HloOpcode::kExp, call_param));
   1053   auto call_computation = module->AddEmbeddedComputation(call_builder.Build());
   1054 
   1055   // Create entry computation which kCalls call_computation and then calls map
   1056   // with map_computation on the result.
   1057   auto builder = HloComputation::Builder(TestName());
   1058   auto param = builder.AddInstruction(
   1059       HloInstruction::CreateParameter(0, vec_shape, "param"));
   1060   auto call = builder.AddInstruction(
   1061       HloInstruction::CreateCall(vec_shape, {param}, call_computation));
   1062   auto map = builder.AddInstruction(
   1063       HloInstruction::CreateMap(vec_shape, {call}, map_computation));
   1064   module->AddEntryComputation(builder.Build());
   1065 
   1066   auto assignment = RunBufferAssignment(module.get());
   1067 
   1068   // Allocations for the map computation should be thread-local and not
   1069   // live-out.
   1070   auto& map_param_alloc = GetTopLevelAllocation(*assignment, map_param);
   1071   EXPECT_FALSE(map_param_alloc.is_entry_computation_parameter());
   1072   EXPECT_FALSE(map_param_alloc.maybe_live_out());
   1073   EXPECT_TRUE(map_param_alloc.is_thread_local());
   1074 
   1075   auto& map_root_alloc = GetTopLevelAllocation(*assignment, map_root);
   1076   EXPECT_FALSE(map_root_alloc.is_entry_computation_parameter());
   1077   EXPECT_FALSE(map_root_alloc.maybe_live_out());
   1078   EXPECT_TRUE(map_root_alloc.is_thread_local());
   1079 
   1080   // Allocations for the call computation should not be thread-local.
   1081   auto& call_param_alloc = GetTopLevelAllocation(*assignment, call_param);
   1082   EXPECT_FALSE(call_param_alloc.is_entry_computation_parameter());
   1083   EXPECT_FALSE(call_param_alloc.maybe_live_out());
   1084   EXPECT_FALSE(call_param_alloc.is_thread_local());
   1085 
   1086   auto& call_root_alloc = GetTopLevelAllocation(*assignment, call_root);
   1087   EXPECT_FALSE(call_root_alloc.is_entry_computation_parameter());
   1088   EXPECT_FALSE(call_root_alloc.is_thread_local());
   1089 
   1090   // Entry computation allocations can be marked maybe-live-out and
   1091   // is_entry_computation_parameter.
   1092   auto& param_alloc = GetTopLevelAllocation(*assignment, param);
   1093   EXPECT_TRUE(param_alloc.is_entry_computation_parameter());
   1094   EXPECT_FALSE(param_alloc.maybe_live_out());
   1095   EXPECT_FALSE(param_alloc.is_thread_local());
   1096 
   1097   auto& map_alloc = GetTopLevelAllocation(*assignment, map);
   1098   EXPECT_FALSE(map_alloc.is_entry_computation_parameter());
   1099   EXPECT_TRUE(map_alloc.maybe_live_out());
   1100   EXPECT_FALSE(map_alloc.is_thread_local());
   1101 }
   1102 
   1103 TEST_F(BufferAssignmentTest, TupleParameterAsOutput) {
   1104   // Test a computation that returns a tuple parameter.
   1105   auto builder = HloComputation::Builder(TestName());
   1106   auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
   1107       0,
   1108       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
   1109                                  ShapeUtil::MakeShape(F32, {}),
   1110                                  ShapeUtil::MakeShape(S32, {42})}),
   1111       "param0"));
   1112 
   1113   auto module = CreateNewModule();
   1114   module->AddEntryComputation(builder.Build());
   1115   auto assignment = RunBufferAssignment(module.get());
   1116 
   1117   // There should be four allocations: one for the tuple itself (the vector of
   1118   // element pointers) and one for each of the three tuple elements.
   1119   EXPECT_EQ(4, assignment->Allocations().size());
   1120 
   1121   // Verify that each buffer allocation is marked as an entry computation
   1122   // parameter and as maybe-live-out.
   1123   ShapeUtil::ForEachSubshape(
   1124       tuple_param->shape(),
   1125       [this, &assignment, tuple_param](const Shape& /*subshape*/,
   1126                                        const ShapeIndex& index) {
   1127         auto allocation = GetAllocation(*assignment, tuple_param, index);
   1128         EXPECT_TRUE(allocation.is_entry_computation_parameter());
   1129         EXPECT_EQ(0, allocation.parameter_number());
   1130         EXPECT_TRUE(allocation.maybe_live_out());
   1131       });
   1132 }
   1133 
   1134 TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) {
   1135   // Test a computation which returns a GetTupleElement of a nested tuple
   1136   // parameter.
   1137   auto builder = HloComputation::Builder(TestName());
   1138   auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
   1139       0,
   1140       ShapeUtil::MakeTupleShape(
   1141           {ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
   1142            ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {42}),
   1143                                       ShapeUtil::MakeShape(S32, {101})})}),
   1144       "param0"));
   1145   auto tuple_element =
   1146       builder.AddInstruction(HloInstruction::CreateGetTupleElement(
   1147           ShapeUtil::GetSubshape(tuple_param->shape(), {1}), tuple_param, 1));
   1148 
   1149   auto module = CreateNewModule();
   1150   module->AddEntryComputation(builder.Build());
   1151   auto assignment = RunBufferAssignment(module.get());
   1152 
   1153   // Only some of the elements of the input param are live out.
   1154   EXPECT_FALSE(
   1155       GetAllocation(*assignment, tuple_param, /*index=*/{}).maybe_live_out());
   1156   // Tuple element at index={1} is live out because GetTupleElement({1})
   1157   // forwards a pointer to this allocation (instead of defining its own buffer).
   1158   EXPECT_TRUE(
   1159       GetAllocation(*assignment, tuple_param, /*index=*/{1}).maybe_live_out());
   1160   EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0})
   1161                   .maybe_live_out());
   1162   EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1})
   1163                   .maybe_live_out());
   1164 
   1165   // The GetTupleElement output is live out.
   1166   EXPECT_TRUE(
   1167       GetTopLevelAllocation(*assignment, tuple_element).maybe_live_out());
   1168 
   1169   // Verify that the allocations of the GetTupleElement's elements match the
   1170   // corresponding tuple parameter allocations, because they alias.
   1171   EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0}),
   1172             GetAllocation(*assignment, tuple_element, /*index=*/{0}));
   1173   EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1}),
   1174             GetAllocation(*assignment, tuple_element, /*index=*/{1}));
   1175 
   1176   // GetTupleElement forwards a pointer to its underlying buffer, so verify
   1177   // that it has the same allocation as the corresponding parameter element.
   1178   EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1}),
   1179             GetTopLevelAllocation(*assignment, tuple_element));
   1180 }
   1181 
   1182 // TODO(b/32248867): Enable when buffer assignment gives allocations to
   1183 // constants.
   1184 TEST_F(BufferAssignmentTest, DISABLED_TupleConstantAsOutput) {
   1185   // Test that a tuple constant which is forwarded to the computation output
   1186   // is properly handled.
   1187   auto builder = HloComputation::Builder(TestName());
   1188   builder.AddInstruction(HloInstruction::CreateConstant(Literal::MakeTuple(
   1189       {Literal::CreateR0<int64>(0).get(), Literal::CreateR0<int64>(1).get()})));
   1190 
   1191   auto module = CreateNewModule();
   1192   module->AddEntryComputation(builder.Build());
   1193   auto assignment = RunBufferAssignment(module.get());
   1194 
   1195   EXPECT_EQ(3, assignment->Allocations().size());
   1196 }
   1197 
   1198 TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) {
   1199   // Test a computation which returns a tuple custom call value.
   1200   auto builder = HloComputation::Builder(TestName());
   1201   auto custom_call = builder.AddInstruction(HloInstruction::CreateCustomCall(
   1202       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
   1203                                  ShapeUtil::MakeShape(S32, {101})}),
   1204       /*operands=*/{}, /*custom_call_target=*/"foo_function"));
   1205   auto module = CreateNewModule();
   1206   module->AddEntryComputation(builder.Build());
   1207   auto assignment = RunBufferAssignment(module.get());
   1208 
   1209   EXPECT_EQ(3, assignment->Allocations().size());
   1210   EXPECT_TRUE(
   1211       GetAllocation(*assignment, custom_call, /*index=*/{}).maybe_live_out());
   1212   EXPECT_TRUE(
   1213       GetAllocation(*assignment, custom_call, /*index=*/{0}).maybe_live_out());
   1214   EXPECT_TRUE(
   1215       GetAllocation(*assignment, custom_call, /*index=*/{1}).maybe_live_out());
   1216 }
   1217 
   1218 TEST_F(BufferAssignmentTest, TupleCallAsOutput) {
   1219   // Test a computation which returns a tuple call value.
   1220   auto module = CreateNewModule();
   1221   auto elem_shape = f32vec4_;
   1222   auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape});
   1223 
   1224   auto sub_builder = HloComputation::Builder(TestName() + "_sub");
   1225   auto sub_param = sub_builder.AddInstruction(
   1226       HloInstruction::CreateParameter(0, elem_shape, "sub_param"));
   1227   auto sub_tuple =
   1228       sub_builder.AddInstruction(HloInstruction::CreateTuple({sub_param}));
   1229   auto sub_computation = module->AddEmbeddedComputation(sub_builder.Build());
   1230 
   1231   auto builder = HloComputation::Builder(TestName());
   1232   auto param = builder.AddInstruction(
   1233       HloInstruction::CreateParameter(0, elem_shape, "param"));
   1234   auto call = builder.AddInstruction(
   1235       HloInstruction::CreateCall(tuple_shape, {param}, sub_computation));
   1236   module->AddEntryComputation(builder.Build());
   1237 
   1238   auto assignment = RunBufferAssignment(module.get());
   1239 
   1240   EXPECT_EQ(3, assignment->Allocations().size());
   1241   // Buffers for call are colocated with the sub-computation.
   1242   EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{}),
   1243             GetAllocation(*assignment, sub_tuple, /*index=*/{}));
   1244   EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{0}),
   1245             GetAllocation(*assignment, sub_param, /*index=*/{}));
   1246   // The parameter isn't aliased with anything.
   1247   EXPECT_NE(GetTopLevelAllocation(*assignment, param),
   1248             GetTopLevelAllocation(*assignment, sub_tuple));
   1249   EXPECT_NE(GetTopLevelAllocation(*assignment, param),
   1250             GetTopLevelAllocation(*assignment, sub_param));
   1251 }
   1252 
   1253 TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) {
   1254   // Test a chain of calls with tuple output. The chain looks like:
   1255   // A: call(B, tuple(param))
   1256   // B: call(C, param)
   1257   // C: call(D, param)
   1258   // D: param
   1259   auto module = CreateNewModule();
   1260   auto elem_shape = f32vec4_;
   1261   auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape});
   1262 
   1263   auto d_builder = HloComputation::Builder(TestName() + "_d");
   1264   auto d_param = d_builder.AddInstruction(
   1265       HloInstruction::CreateParameter(0, tuple_shape, "d_param"));
   1266   auto d_computation = d_builder.Build();
   1267 
   1268   auto c_builder = HloComputation::Builder(TestName() + "_c");
   1269   auto c_param = c_builder.AddInstruction(
   1270       HloInstruction::CreateParameter(0, tuple_shape, "c_param"));
   1271   auto c_call = c_builder.AddInstruction(
   1272       HloInstruction::CreateCall(tuple_shape, {c_param}, d_computation.get()));
   1273   auto c_computation = c_builder.Build();
   1274 
   1275   auto b_builder = HloComputation::Builder(TestName() + "_b");
   1276   auto b_param = b_builder.AddInstruction(
   1277       HloInstruction::CreateParameter(0, tuple_shape, "b_param"));
   1278   auto b_call = b_builder.AddInstruction(
   1279       HloInstruction::CreateCall(tuple_shape, {b_param}, c_computation.get()));
   1280   auto b_computation = b_builder.Build();
   1281 
   1282   auto a_builder = HloComputation::Builder(TestName());
   1283   auto a_param = a_builder.AddInstruction(
   1284       HloInstruction::CreateParameter(0, elem_shape, "param"));
   1285   auto a_tuple =
   1286       a_builder.AddInstruction(HloInstruction::CreateTuple({a_param}));
   1287   auto a_call = a_builder.AddInstruction(
   1288       HloInstruction::CreateCall(tuple_shape, {a_tuple}, b_computation.get()));
   1289   auto a_computation = a_builder.Build();
   1290 
   1291   // Add the computations in an order that doesn't match the dependency
   1292   // post-order, to shake out more possible bugs.
   1293   module->AddEmbeddedComputation(std::move(d_computation));
   1294   module->AddEmbeddedComputation(std::move(c_computation));
   1295   module->AddEntryComputation(std::move(a_computation));
   1296   module->AddEmbeddedComputation(std::move(b_computation));
   1297 
   1298   auto assignment = RunBufferAssignment(module.get());
   1299 
   1300   // Buffers for call are colocated with the sub-computations.
   1301   EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{}),
   1302             GetAllocation(*assignment, b_call, /*index=*/{}));
   1303   EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{}),
   1304             GetAllocation(*assignment, c_call, /*index=*/{}));
   1305   EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{}),
   1306             GetAllocation(*assignment, d_param, /*index=*/{}));
   1307   EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{0}),
   1308             GetAllocation(*assignment, b_call, /*index=*/{0}));
   1309   EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{0}),
   1310             GetAllocation(*assignment, c_call, /*index=*/{0}));
   1311   EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{0}),
   1312             GetAllocation(*assignment, d_param, /*index=*/{0}));
   1313   // The parameters aren't aliased with anything.
   1314   EXPECT_TRUE(BuffersDistinct({a_param}, {b_param}, *assignment));
   1315   EXPECT_TRUE(BuffersDistinct({a_param}, {c_param}, *assignment));
   1316   EXPECT_TRUE(BuffersDistinct({a_param}, {d_param}, *assignment));
   1317   EXPECT_TRUE(BuffersDistinct({b_param}, {c_param}, *assignment));
   1318   EXPECT_TRUE(BuffersDistinct({b_param}, {d_param}, *assignment));
   1319   EXPECT_TRUE(BuffersDistinct({c_param}, {d_param}, *assignment));
   1320 }
   1321 
   1322 TEST_F(BufferAssignmentTest, BitcastAsOutput) {
   1323   // Test a computation which returns a bitcast value.
   1324   auto builder = HloComputation::Builder(TestName());
   1325   auto param = builder.AddInstruction(HloInstruction::CreateParameter(
   1326       0, ShapeUtil::MakeShape(F32, {42}), "param"));
   1327   auto bitcast = builder.AddInstruction(
   1328       HloInstruction::CreateUnary(param->shape(), HloOpcode::kBitcast, param));
   1329 
   1330   auto module = CreateNewModule();
   1331   module->AddEntryComputation(builder.Build());
   1332   auto assignment = RunBufferAssignment(module.get());
   1333 
   1334   // Bitcast should get the same allocation as the param.
   1335   EXPECT_EQ(1, assignment->Allocations().size());
   1336   EXPECT_EQ(GetTopLevelAllocation(*assignment, param),
   1337             GetTopLevelAllocation(*assignment, bitcast));
   1338 }
   1339 
   1340 TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) {
   1341   // Test a computation with an output that has an ambiguous points-to set.
   1342   // This is constructed using a select between tuple-shaped operands.
   1343   auto builder = HloComputation::Builder(TestName());
   1344   auto tuple_shape =
   1345       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4})});
   1346 
   1347   auto tuple_param0 = builder.AddInstruction(
   1348       HloInstruction::CreateParameter(0, tuple_shape, "param0"));
   1349   auto tuple_param1 = builder.AddInstruction(
   1350       HloInstruction::CreateParameter(1, tuple_shape, "param1"));
   1351   auto pred_param = builder.AddInstruction(HloInstruction::CreateParameter(
   1352       2, ShapeUtil::MakeShape(PRED, {}), "param2"));
   1353   auto select = builder.AddInstruction(HloInstruction::CreateTernary(
   1354       tuple_shape, HloOpcode::kSelect, pred_param, tuple_param0, tuple_param1));
   1355 
   1356   auto module = CreateNewModule();
   1357   module->AddEntryComputation(builder.Build());
   1358   auto assignment = RunBufferAssignment(module.get());
   1359 
   1360   // Select shallow copies one of its operands so it defines its own top-level
   1361   // buffer and receives its own allocation.
   1362   auto select_alloc = GetTopLevelAllocation(*assignment, select);
   1363   EXPECT_EQ(1, select_alloc.assigned_buffers().size());
   1364   EXPECT_EQ(select,
   1365             select_alloc.assigned_buffers().begin()->first->instruction());
   1366 
   1367   // The buffer for the tuple element of the select is forwarded from one of
   1368   // its operands, which cannot be determined statically. Therefore its slices
   1369   // should include the slices of the corresponding elements of both parameters.
   1370   auto element_slices = assignment->GetAllSlices(select, /*index=*/{0});
   1371   EXPECT_EQ(2, element_slices.size());
   1372   EXPECT_THAT(element_slices,
   1373               ::testing::UnorderedElementsAre(
   1374                   assignment->GetUniqueSlice(tuple_param0, /*index=*/{0})
   1375                       .ConsumeValueOrDie(),
   1376                   assignment->GetUniqueSlice(tuple_param1, /*index=*/{0})
   1377                       .ConsumeValueOrDie()));
   1378 }
   1379 
   1380 // TODO(b/34669761): Remove this test when buffers are allowed to share
   1381 // allocations.
   1382 TEST_F(BufferAssignmentTest, TupleBufferNotReused) {
   1383   // Test that a copy of a tuple element does not reuse the tuple buffer.
   1384   auto builder = HloComputation::Builder(TestName());
   1385   auto scalar_shape = ShapeUtil::MakeShape(F32, {});
   1386   auto param = builder.AddInstruction(
   1387       HloInstruction::CreateParameter(0, scalar_shape, "param0"));
   1388   auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({param}));
   1389   auto tuple_element = builder.AddInstruction(
   1390       HloInstruction::CreateGetTupleElement(scalar_shape, tuple, 0));
   1391   auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
   1392       scalar_shape, HloOpcode::kCopy, tuple_element));
   1393 
   1394   auto module = CreateNewModule();
   1395   module->AddEntryComputation(builder.Build());
   1396   auto assignment = RunBufferAssignment(module.get());
   1397 
   1398   // There should be no buffer reuse. The copy should not reuse the tuple
   1399   // buffer.
   1400   EXPECT_EQ(3, assignment->Allocations().size());
   1401   EXPECT_NE(GetTopLevelAllocation(*assignment, tuple),
   1402             GetTopLevelAllocation(*assignment, copy));
   1403 }
   1404 
   1405 TEST_F(BufferAssignmentTest, OneTempAllocation) {
   1406   // Test a computation that requires multiple temp buffers, and ensure they
   1407   // are combined into a single allocation.
   1408   auto builder = HloComputation::Builder(TestName());
   1409   Shape shape_2x3 = ShapeUtil::MakeShape(F32, {2, 3});
   1410   Shape shape_2x4 = ShapeUtil::MakeShape(F32, {2, 4});
   1411   Shape shape_3x4 = ShapeUtil::MakeShape(F32, {3, 4});
   1412   Shape shape_4x4 = ShapeUtil::MakeShape(F32, {4, 4});
   1413   Shape shape_5x4 = ShapeUtil::MakeShape(F32, {5, 4});
   1414 
   1415   // dot_ab and dot_bc each get their own temp buffer within one allocation.
   1416   auto param_a = builder.AddInstruction(
   1417       HloInstruction::CreateParameter(0, shape_2x3, "param_a"));
   1418   auto param_b = builder.AddInstruction(
   1419       HloInstruction::CreateParameter(1, shape_3x4, "param_b"));
   1420   auto param_c = builder.AddInstruction(
   1421       HloInstruction::CreateParameter(2, shape_4x4, "param_c"));
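          // Standard matrix-multiply dimension numbers: contract lhs dimension 1
          // with rhs dimension 0.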
   1422   DotDimensionNumbers dot_dnums;
   1423   dot_dnums.add_lhs_contracting_dimensions(1);
   1424   dot_dnums.add_rhs_contracting_dimensions(0);
   1425   auto dot_ab = builder.AddInstruction(
   1426       HloInstruction::CreateDot(shape_2x4, param_a, param_b, dot_dnums));
   1427   auto dot_bc = builder.AddInstruction(
   1428       HloInstruction::CreateDot(shape_3x4, param_b, param_c, dot_dnums));
   1429   builder.AddInstruction(
   1430       HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 1));
   1431 
   1432   // Run buffer assignment with alignment=1.
   1433   auto module = CreateNewModule();
   1434   module->AddEntryComputation(builder.Build());
   1435   auto assignment = RunBufferAssignment(module.get(), /*alignment=*/1);
   1436 
   1437   // There are 5 allocations: 3 parameters, 1 output, and 1 temp.
   1438   EXPECT_EQ(5, assignment->Allocations().size());
   1439 
   1440   // Ensure the temp buffers for dot_ab and dot_bc share a single allocation,
   1441   // and that each occupies a different slice of that allocation.
   1442   BufferAllocation::Slice slice_ab =
   1443       assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie();
   1444   BufferAllocation::Slice slice_bc =
   1445       assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie();
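          // dot_ab is f32[2,4] (32 bytes) and dot_bc is f32[3,4] (48 bytes); with
          // alignment 1 they pack into a single 80-byte allocation.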
   1446   EXPECT_EQ(slice_ab.allocation(), slice_bc.allocation());
   1447   EXPECT_NE(slice_ab, slice_bc);
   1448   EXPECT_EQ(32, slice_ab.size());
   1449   EXPECT_EQ(48, slice_bc.size());
   1450   EXPECT_EQ(80, slice_ab.allocation()->size());
   1451   EXPECT_EQ(80, slice_bc.allocation()->size());
   1452 
   1453   // Re-run buffer assignment with alignment=64.
   1454   assignment = RunBufferAssignment(module.get(), /*alignment=*/64);
   1455   EXPECT_EQ(5, assignment->Allocations().size());
   1456   slice_ab = assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie();
   1457   slice_bc = assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie();
   1458   EXPECT_EQ(slice_ab.allocation(), slice_bc.allocation());
   1459   EXPECT_NE(slice_ab, slice_bc);
   1460   EXPECT_EQ(32, slice_ab.size());
   1461   EXPECT_EQ(48, slice_bc.size());
   1462   // Ensure the offsets and allocation size account for the alignment, without
   1463   // assuming which buffer gets assigned first.
   1464   if (slice_ab.offset() == 0) {
   1465     EXPECT_EQ(64, slice_bc.offset());
   1466     EXPECT_EQ(64 + 48, slice_ab.allocation()->size());
   1467     EXPECT_EQ(64 + 48, slice_bc.allocation()->size());
   1468   } else {
   1469     EXPECT_EQ(64, slice_ab.offset());
   1470     EXPECT_EQ(0, slice_bc.offset());
   1471     EXPECT_EQ(64 + 32, slice_ab.allocation()->size());
   1472     EXPECT_EQ(64 + 32, slice_bc.allocation()->size());
   1473   }
   1474 }
   1475 
   1476 class WhileBufferAssignmentTest : public HloTestBase {
   1477  protected:
   1478   std::unique_ptr<HloComputation> BuildWhileConditionComputation(
   1479       const string& name) {
   1480     auto builder = HloComputation::Builder(name);
   1481     builder.AddInstruction(
   1482         HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
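            // The loop state parameter is unused: the condition compares two
            // constants (0 < 10). In these tests the while computations are only
            // analyzed for buffer assignment, not executed.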
   1483     auto zero = builder.AddInstruction(
   1484         HloInstruction::CreateConstant(Literal::CreateR0<int>(0)));
   1485     auto ten = builder.AddInstruction(
   1486         HloInstruction::CreateConstant(Literal::CreateR0<int>(10)));
   1487     builder.AddInstruction(HloInstruction::CreateBinary(
   1488         ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten));
   1489     return builder.Build();
   1490   }
   1491 
   1492   std::unique_ptr<HloComputation> BuildWhileBodyComputation(
   1493       const string& name) {
   1494     auto builder = HloComputation::Builder(name);
   1495     auto loop_state = builder.AddInstruction(
   1496         HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
   1497     auto input = builder.AddInstruction(
   1498         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 0));
   1499     auto weights = builder.AddInstruction(
   1500         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
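            // The multiply defines a fresh buffer for the output element; the tuple
            // below repackages (input, weights, output) as the next loop state.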
   1501     auto output = builder.AddInstruction(HloInstruction::CreateBinary(
   1502         data_shape_, HloOpcode::kMultiply, input, weights));
   1503     builder.AddInstruction(
   1504         HloInstruction::CreateTuple({input, weights, output}));
   1505     return builder.Build();
   1506   }
   1507 
   1508   std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
   1509                                                         int64 alignment = 1) {
   1510     auto sequence =
   1511         CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie();
   1512     return BufferAssigner::Run(
   1513                module, xla::MakeUnique<SequentialHloOrdering>(module, sequence),
   1514                ByteSizeOf,
   1515                [alignment](LogicalBuffer::Color) { return alignment; })
   1516         .ConsumeValueOrDie();
   1517   }
   1518 
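          // Size function handed to BufferAssigner::Run; tuples are sized as arrays
          // of host pointers, hence the sizeof(void*) pointer size.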
   1519   static int64 ByteSizeOf(const LogicalBuffer& buffer) {
   1520     return ShapeUtil::ByteSizeOf(buffer.shape(), sizeof(void*));
   1521   }
   1522 
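          // The loop state is an (input, weights, output) tuple of f32[4] buffers.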
   1523   Shape data_shape_ = ShapeUtil::MakeShape(F32, {4});
   1524   Shape loop_state_shape_ =
   1525       ShapeUtil::MakeTupleShape({data_shape_, data_shape_, data_shape_});
   1526 };
   1527 
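        // Adds the copy instructions required (e.g., for while loop state) before
        // buffer assignment can run on the module.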
   1528 static void RunCopyInsertion(HloModule* module) {
   1529   CopyInsertion copy_insertion;
   1530   EXPECT_IS_OK(copy_insertion.Run(module).status());
   1531 }
   1532 
   1533 TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
   1534   auto module = xla::MakeUnique<HloModule>(TestName());
   1535   auto builder = HloComputation::Builder("entry");
   1536 
   1537   auto input0 = builder.AddInstruction(
   1538       HloInstruction::CreateParameter(0, data_shape_, "input0"));
   1539   auto weights0 = builder.AddInstruction(
   1540       HloInstruction::CreateParameter(1, data_shape_, "weights0"));
   1541   auto weights1 = builder.AddInstruction(
   1542       HloInstruction::CreateParameter(2, data_shape_, "weights1"));
   1543 
   1544   auto zero = builder.AddInstruction(
   1545       HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
   1546   auto output0 = builder.AddInstruction(
   1547       HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
   1548   auto output1 = builder.AddInstruction(
   1549       HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
   1550 
   1551   auto cond0 =
   1552       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
   1553   auto body0 =
   1554       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
   1555 
   1556   auto tuple0 = builder.AddInstruction(
   1557       HloInstruction::CreateTuple({input0, weights0, output0}));
   1558   auto while0 = builder.AddInstruction(
   1559       HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
   1560 
   1561   auto cond1 =
   1562       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
   1563   auto body1 =
   1564       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
   1565   auto input1 = builder.AddInstruction(
   1566       HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
   1567   auto tuple1 = builder.AddInstruction(
   1568       HloInstruction::CreateTuple({input1, weights1, output1}));
   1569   auto while1 = builder.AddInstruction(
   1570       HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
   1571 
   1572   module->AddEntryComputation(builder.Build());
   1573   RunCopyInsertion(module.get());
   1574   auto assignment = RunBufferAssignment(module.get());
   1575 
   1576   // Verify that 'input0' and its read-only use while0{0} alias.
   1577   EXPECT_EQ(assignment->GetUniqueSlice(input0, {}).ConsumeValueOrDie(),
   1578             assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie());
   1579   // Verify that 'weights0' and its read-only use while0{1} alias.
   1580   EXPECT_EQ(assignment->GetUniqueSlice(weights0, {}).ConsumeValueOrDie(),
   1581             assignment->GetUniqueSlice(while0, {1}).ConsumeValueOrDie());
   1582   // Verify that 'while0{2}' and its read-only use while1{0} alias.
   1583   EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).ConsumeValueOrDie(),
   1584             assignment->GetUniqueSlice(while1, {0}).ConsumeValueOrDie());
   1585   // Verify that 'weights1' and its read-only use while1{1} alias.
   1586   EXPECT_EQ(assignment->GetUniqueSlice(weights1, {}).ConsumeValueOrDie(),
   1587             assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie());
   1588 }
   1589 
   1590 // Tests that the colocated buffers for while instructions are properly assigned
   1591 // during buffer assignment such that the result tuple elements are not assigned
   1592 // to the same buffer.
   1593 //
   1594 // %infeed --> %while.0 --> %while.1 --+
   1595 //                                     +-- %tuple
   1596 //   %zero -->   %add   --> %while.2 --+
   1597 //
   1598 // Execution Order:
   1599 // %infeed -> %while.0 -> %while.1 -> %zero -> %add -> %while.2 -> %tuple
   1600 //
   1601 // The HLO computation used in this test requires specific ordering to expose
   1602 // the bug (b/72496031). During buffer assignment, the visitation order of
   1603   // colocated buffers is %while.2 -> %while.0 -> %while.1, and the buffer
   1604   // assignment was coalescing the colocated buffers for all 3 while instructions,
   1605   // thereby assigning the same buffer to the two result tuple elements.
   1606 TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
   1607   const Shape r0s32 = ShapeUtil::MakeShape(S32, {});
   1608 
   1609   // Builds a condition computation: x -> x < 4
   1610   auto build_cond = [&]() {
   1611     auto builder = HloComputation::Builder("cond");
   1612     auto const4 = builder.AddInstruction(
   1613         HloInstruction::CreateConstant(Literal::CreateR0<int>(4)));
   1614     auto param =
   1615         builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
   1616     builder.AddInstruction(HloInstruction::CreateBinary(
   1617         ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, const4));
   1618     return builder.Build();
   1619   };
   1620 
   1621   // Builds a body computation: x -> x + 9
   1622   auto build_body = [&]() {
   1623     auto builder = HloComputation::Builder("body");
   1624     auto const9 = builder.AddInstruction(
   1625         HloInstruction::CreateConstant(Literal::CreateR0<int>(9)));
   1626     auto param =
   1627         builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
   1628     builder.AddInstruction(
   1629         HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, param, const9));
   1630     return builder.Build();
   1631   };
   1632 
   1633   // Build the entry computation as described in the comment above.
   1634   auto module = xla::MakeUnique<HloModule>(TestName());
   1635   auto builder = HloComputation::Builder("entry");
   1636 
   1637   auto infeed = builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, ""));
   1638   auto cond0 = module->AddEmbeddedComputation(build_cond());
   1639   auto body0 = module->AddEmbeddedComputation(build_body());
   1640   auto while0 = builder.AddInstruction(
   1641       HloInstruction::CreateWhile(r0s32, cond0, body0, infeed));
   1642 
   1643   auto cond1 = module->AddEmbeddedComputation(build_cond());
   1644   auto body1 = module->AddEmbeddedComputation(build_body());
   1645   auto while1 = builder.AddInstruction(
   1646       HloInstruction::CreateWhile(r0s32, cond1, body1, while0));
   1647 
   1648   auto zero = builder.AddInstruction(
   1649       HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
   1650   auto add = builder.AddInstruction(
   1651       HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, zero, zero));
   1652   auto cond2 = module->AddEmbeddedComputation(build_cond());
   1653   auto body2 = module->AddEmbeddedComputation(build_body());
   1654   auto while2 = builder.AddInstruction(
   1655       HloInstruction::CreateWhile(r0s32, cond2, body2, add));
   1656 
   1657   auto tuple =
   1658       builder.AddInstruction(HloInstruction::CreateTuple({while2, while1}));
   1659   module->AddEntryComputation(builder.Build());
   1660 
   1661   // Run CopyInsertion and check that the graph constructed above does not need
   1662   // any copies inserted for BufferAssignment to run.
   1663   int64 instruction_count = module->instruction_count();
   1664   CopyInsertion copy_insertion;
   1665   ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
   1666   ASSERT_EQ(instruction_count, module->instruction_count());
   1667 
   1668   // Create a sequential order among all the instructions in the entry
   1669   // computation, since the issue this test stresses depends on the order in
   1670   // which the nodes are traversed during BufferAssignment.
   1671   SequentialHloOrdering::HloModuleSequence sequence;
   1672   sequence[module->entry_computation()] = {infeed, while0, while1, zero,
   1673                                            add,    while2, tuple};
   1674   TF_ASSERT_OK_AND_ASSIGN(
   1675       auto assignment,
   1676       BufferAssigner::Run(
   1677           module.get(),
   1678           xla::MakeUnique<SequentialHloOrdering>(module.get(), sequence),
   1679           backend().compiler()->BufferSizeBytesFunction(),
   1680           [](LogicalBuffer::Color) { return 1; }));
   1681 
   1682   // The result tuple elements must be assigned with different buffers.
   1683   TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0}));
   1684   TF_ASSERT_OK_AND_ASSIGN(auto slice1, assignment->GetUniqueSlice(tuple, {1}));
   1685   EXPECT_NE(slice0, slice1);
   1686 
   1687   // while0 and while1 result buffers must be equal to slice1.
   1688   TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
   1689                           assignment->GetUniqueSlice(while0, {}));
   1690   TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
   1691                           assignment->GetUniqueSlice(while1, {}));
   1692   EXPECT_EQ(slice1, slice_while0);
   1693   EXPECT_EQ(slice1, slice_while1);
   1694 
   1695   // while2 result buffer must be equal to slice0.
   1696   TF_ASSERT_OK_AND_ASSIGN(auto slice_while2,
   1697                           assignment->GetUniqueSlice(while2, {}));
   1698   EXPECT_EQ(slice0, slice_while2);
   1699 }
   1700 
   1701 TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
   1702   auto module = xla::MakeUnique<HloModule>(TestName());
   1703   auto builder = HloComputation::Builder("entry");
   1704 
   1705   auto input0 = builder.AddInstruction(
   1706       HloInstruction::CreateParameter(0, data_shape_, "input0"));
   1707   auto weights0 = builder.AddInstruction(
   1708       HloInstruction::CreateParameter(1, data_shape_, "weights0"));
   1709 
   1710   auto zero = builder.AddInstruction(
   1711       HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
   1712   auto output0 = builder.AddInstruction(
   1713       HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
   1714 
   1715   auto cond0 =
   1716       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
   1717   auto body0 =
   1718       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
   1719 
   1720   auto tuple0 = builder.AddInstruction(
   1721       HloInstruction::CreateTuple({input0, weights0, output0}));
   1722   auto while0 = builder.AddInstruction(
   1723       HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
   1724 
   1725   auto cond1 =
   1726       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
   1727   auto body1 =
   1728       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
   1729 
   1730   auto while1 = builder.AddInstruction(
   1731       HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0));
   1732 
   1733   module->AddEntryComputation(builder.Build());
   1734   RunCopyInsertion(module.get());
   1735   auto assignment = RunBufferAssignment(module.get());
   1736 
   1737   // The corresponding buffers of while0 and while1 should alias exactly.
   1738   EXPECT_EQ(assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie(),
   1739             assignment->GetUniqueSlice(while1, {0}).ConsumeValueOrDie());
   1740   EXPECT_EQ(assignment->GetUniqueSlice(while0, {1}).ConsumeValueOrDie(),
   1741             assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie());
   1742   EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).ConsumeValueOrDie(),
   1743             assignment->GetUniqueSlice(while1, {2}).ConsumeValueOrDie());
   1744 }
   1745 
   1746 TEST_F(BufferAssignmentTest, TwoCalls) {
   1747   auto module = xla::MakeUnique<HloModule>(TestName());
   1748   Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {});
   1749   HloComputation* sub_computation;
   1750   {
   1751     auto builder = HloComputation::Builder(TestName() + "_sub_comp");
   1752     auto param = builder.AddInstruction(
   1753         HloInstruction::CreateParameter(0, r0f32, "param"));
   1754     auto constant1 = builder.AddInstruction(
   1755         HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
   1756     auto add = builder.AddInstruction(
   1757         HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, constant1));
   1758     sub_computation = module->AddEmbeddedComputation(builder.Build(add));
   1759   }
   1760   auto builder = HloComputation::Builder(TestName());
   1761   auto constant2 = builder.AddInstruction(
   1762       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
   1763   auto constant3 = builder.AddInstruction(
   1764       HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
   1765   auto call1 = builder.AddInstruction(
   1766       HloInstruction::CreateCall(r0f32, {constant2}, sub_computation));
   1767   auto call2 = builder.AddInstruction(
   1768       HloInstruction::CreateCall(r0f32, {constant3}, sub_computation));
   1769   auto add1 = builder.AddInstruction(
   1770       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call1, constant2));
   1771   auto add2 = builder.AddInstruction(
   1772       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call2, add1));
   1773   module->AddEntryComputation(builder.Build(add2));
   1774 
   1775   {
   1776     FlattenCallGraph flatten;
   1777     TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
   1778     EXPECT_TRUE(result);
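            // Building the call graph here only exercises its construction on the
            // flattened module; the result is unused.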
   1779     std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
   1780   }
   1781 
   1782   RunCopyInsertion(module.get());
   1783   auto assignment = RunBufferAssignment(module.get());
   1784 
   1785   EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment));
   1786 }
   1787 
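        // Returns true iff 'sequence' visits each instruction exactly once, and
        // only after all of its operands and control predecessors.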
   1788 static bool IsPostOrderTraversal(
   1789     const std::vector<const HloInstruction*>& sequence) {
   1790   tensorflow::gtl::FlatSet<const HloInstruction*> seen_so_far;
   1791   auto has_not_been_seen_yet = [&](const HloInstruction* instruction) {
   1792     return seen_so_far.count(instruction) == 0;
   1793   };
   1794 
   1795   for (auto instruction : sequence) {
   1796     if (std::any_of(instruction->operands().begin(),
   1797                     instruction->operands().end(), has_not_been_seen_yet) ||
   1798         std::any_of(instruction->control_predecessors().begin(),
   1799                     instruction->control_predecessors().end(),
   1800                     has_not_been_seen_yet)) {
   1801       return false;  // Not a post order.
   1802     }
   1803     if (!seen_so_far.insert(instruction).second) {
   1804       return false;  // Not a "traversal".
   1805     }
   1806   }
   1807 
   1808   return true;
   1809 }
   1810 
   1811 TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
   1812   auto module = xla::MakeUnique<HloModule>(TestName());
   1813   auto builder = HloComputation::Builder(TestName());
   1814 
   1815   auto zero = builder.AddInstruction(
   1816       HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
   1817   auto one = builder.AddInstruction(
   1818       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
   1819 
   1820   auto input0 = builder.AddInstruction(
   1821       HloInstruction::CreateParameter(0, data_shape_, "input0"));
   1822   auto weights0 = builder.AddInstruction(
   1823       HloInstruction::CreateParameter(1, data_shape_, "weights0"));
   1824   auto output0 = builder.AddInstruction(
   1825       HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
   1826 
   1827   auto input1 = builder.AddInstruction(
   1828       HloInstruction::CreateParameter(2, data_shape_, "input1"));
   1829   auto weights1 = builder.AddInstruction(
   1830       HloInstruction::CreateParameter(3, data_shape_, "weights1"));
   1831   auto output1 = builder.AddInstruction(
   1832       HloInstruction::CreateBroadcast(data_shape_, one, {1}));
   1833 
   1834   auto cond =
   1835       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
   1836   auto body = module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
   1837 
   1838   auto tuple0 = builder.AddInstruction(
   1839       HloInstruction::CreateTuple({input0, weights0, output0}));
   1840   auto tuple1 = builder.AddInstruction(
   1841       HloInstruction::CreateTuple({input1, weights1, output1}));
   1842 
   1843   auto while0 = builder.AddInstruction(
   1844       HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple0));
   1845   auto while1 = builder.AddInstruction(
   1846       HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1));
   1847 
   1848   auto gte0 = builder.AddInstruction(
   1849       HloInstruction::CreateGetTupleElement(data_shape_, while0, 0));
   1850   auto gte1 = builder.AddInstruction(
   1851       HloInstruction::CreateGetTupleElement(data_shape_, while1, 1));
   1852   auto root_add = builder.AddInstruction(HloInstruction::CreateBinary(
   1853       while0->shape(), HloOpcode::kAdd, gte0, gte1));
   1854 
   1855   module->AddEntryComputation(builder.Build());
   1856 
   1857   {
   1858     FlattenCallGraph flatten;
   1859     TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
   1860     EXPECT_TRUE(result);
   1861   }
   1862 
   1863   RunCopyInsertion(module.get());
   1864 
   1865   auto sequence =
   1866       CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie();
   1867 
   1868   // To trigger b/38494731, we want a specific HLO sequence for the
   1869   // root computation, so we overwrite that entry with a manually
   1870   // crafted sequence.
   1871   sequence[module->entry_computation()] = {
   1872       input1, weights1, one,     output1, while1->operand(0), while1,
   1873       input0, weights0, zero,    output0, while0->operand(0), while0,
   1874       gte0,   gte1,     root_add};
   1875 
   1876   // If this ASSERT_TRUE fails, we constructed a bogus sequence above
   1877   // and this test itself is buggy.
   1878   ASSERT_TRUE(IsPostOrderTraversal(sequence[module->entry_computation()]));
   1879 
   1880   auto assignment =
   1881       BufferAssigner::Run(
   1882           module.get(),
   1883           xla::MakeUnique<SequentialHloOrdering>(module.get(), sequence),
   1884           ByteSizeOf, [](LogicalBuffer::Color) { return 1; })
   1885           .ConsumeValueOrDie();
   1886 
   1887   EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
   1888 }
   1889 
   1890 TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
   1891   auto module = xla::MakeUnique<HloModule>(TestName());
   1892   auto builder = HloComputation::Builder("entry");
   1893 
   1894   auto input0 = builder.AddInstruction(
   1895       HloInstruction::CreateParameter(0, data_shape_, "input0"));
   1896   auto weights0 = builder.AddInstruction(
   1897       HloInstruction::CreateParameter(1, data_shape_, "weights0"));
   1898 
   1899   auto zero = builder.AddInstruction(
   1900       HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
   1901   auto output0 = builder.AddInstruction(
   1902       HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
   1903   auto output1 = builder.AddInstruction(
   1904       HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
   1905 
   1906   auto cond0 =
   1907       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
   1908   auto body0 =
   1909       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
   1910 
   1911   auto tuple0 = builder.AddInstruction(
   1912       HloInstruction::CreateTuple({input0, weights0, output0}));
   1913   auto while0 = builder.AddInstruction(
   1914       HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
   1915 
   1916   // Get the output of 'while0' and feed it as input to 'while1'.
   1917   auto while0_out = builder.AddInstruction(
   1918       HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
   1919 
   1920   auto cond1 =
   1921       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
   1922   auto body1 =
   1923       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
   1924 
   1925   auto tuple1 = builder.AddInstruction(
   1926       HloInstruction::CreateTuple({while0_out, weights0, output1}));
   1927   auto while1 = builder.AddInstruction(
   1928       HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
   1929 
   1930   // Get the output of 'while1' so that it is live out of the computation.
   1931   auto while1_out = builder.AddInstruction(
   1932       HloInstruction::CreateGetTupleElement(data_shape_, while1, 2));
   1933 
   1934   module->AddEntryComputation(builder.Build());
   1935   RunCopyInsertion(module.get());
   1936   auto assignment = RunBufferAssignment(module.get());
   1937   // Get BufferAllocation for root instruction.
   1938   auto* root_alloc = assignment->GetUniqueTopLevelSlice(while1_out)
   1939                          .ConsumeValueOrDie()
   1940                          .allocation();
   1941   // Test that root instruction allocation is live out.
   1942   EXPECT_TRUE(root_alloc->maybe_live_out());
   1943   // Test that root instruction allocation is not an entry parameter.
   1944   EXPECT_FALSE(root_alloc->is_entry_computation_parameter());
   1945 }
   1946 
   1947 }  // namespace
   1948 }  // namespace xla
   1949