Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/heap_simulator.h"
     17 
     18 #include <memory>
     19 #include <utility>
     20 #include <vector>
     21 
     22 #include "absl/container/flat_hash_map.h"
     23 #include "absl/memory/memory.h"
     24 #include "tensorflow/compiler/xla/literal.h"
     25 #include "tensorflow/compiler/xla/service/buffer_value.h"
     26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     28 #include "tensorflow/compiler/xla/service/hlo_module.h"
     29 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
     30 #include "tensorflow/compiler/xla/service/hlo_value.h"
     31 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
     32 #include "tensorflow/compiler/xla/status_macros.h"
     33 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     34 #include "tensorflow/core/lib/core/status_test_util.h"
     35 
     36 namespace xla {
     37 namespace {
     38 
// Fixture for the tests below that query HeapSimulator for the minimum memory
// needed by a scheduled module or computation. It adds nothing on top of
// HloTestBase.
class MinimumMemoryForSequenceTest : public HloTestBase {};
     40 
     41 TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
     42   auto module = CreateNewVerifiedModule();
     43   const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
     44   const Shape tuple_shape =
     45       ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});
     46 
     47   auto cond_builder = HloComputation::Builder("WhileCond");
     48   // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
     49   HloInstruction* cond_param = cond_builder.AddInstruction(
     50       HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
     51   HloInstruction* cond_iter = cond_builder.AddInstruction(
     52       HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
     53   HloInstruction* cond_data = cond_builder.AddInstruction(
     54       HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
     55   // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
     56   HloInstruction* cond_lt = cond_builder.AddInstruction(
     57       HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
     58                                     cond_data, ComparisonDirection::kLt));
     59   HloComputation* cond_computation =
     60       module->AddEmbeddedComputation(cond_builder.Build());
     61 
     62   auto body_builder = HloComputation::Builder("WhileBody");
     63   // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
     64   HloInstruction* body_param = body_builder.AddInstruction(
     65       HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
     66   HloComputation* body_computation =
     67       module->AddEmbeddedComputation(body_builder.Build());
     68 
     69   auto builder = HloComputation::Builder(TestName());
     70   // Entry params: 8 bytes (4 bytes per param), TOTAL=8
     71   HloInstruction* iter = builder.AddInstruction(
     72       HloInstruction::CreateParameter(0, scalar_shape, "param_iter"));
     73   HloInstruction* data = builder.AddInstruction(
     74       HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
     75   // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24
     76   HloInstruction* tuple =
     77       builder.AddInstruction(HloInstruction::CreateTuple({iter, data}));
     78   // While: 8 bytes (4 bytes per element), TOTAL=32
     79   // Both cond and body use a max of 24 bytes, TOTAL=56
     80   HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
     81       tuple_shape, cond_computation, body_computation, tuple));
     82   HloComputation* entry_computation =
     83       module->AddEntryComputation(builder.Build());
     84 
     85   auto size_fn = [](const BufferValue& buffer) {
     86     return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
     87   };
     88 
     89   HloSchedule schedule(module.get());
     90   schedule.set_sequence(cond_computation,
     91                         {cond_param, cond_iter, cond_data, cond_lt});
     92   schedule.set_sequence(body_computation, {body_param});
     93   schedule.set_sequence(entry_computation, {iter, data, tuple, while_op});
     94   TF_ASSERT_OK(schedule.Verify());
     95 
     96   EXPECT_EQ(
     97       56,
     98       HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie());
     99 }
    100 
// Checks that MinimumMemoryForComputation, when handed precomputed memory
// figures for the while condition and body, accounts for those subcomputations
// when sizing the entry computation.
TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) {
  // The module built below corresponds to the following HLO text (the entry
  // instructions are listed here in a different order than they are created
  // and scheduled in the code):
  //
  // HloModule SubcomputationAccounting
  //
  // %WhileBody (body_param: f32[4]) -> f32[4] {
  //   %body_param = f32[4]{0} parameter(0)
  //   %constant.1 = f32[4]{0} constant({1, 1, 1, 1})
  //   ROOT %subtract = f32[4]{0} subtract(f32[4]{0} %body_param,
  //                                       f32[4]{0} %constant.1)
  // }
  //
  // %WhileCond (cond_param: f32[4]) -> pred[] {
  //   %cond_param = f32[4]{0} parameter(0)
  //   %slice = f32[1]{0} slice(f32[4]{0} %cond_param), slice={[0:1]}
  //   %reshape = f32[] reshape(f32[1]{0} %slice)
  //   %constant = f32[] constant(0)
  //   ROOT %not-equal-to = pred[] compare(f32[] %reshape, f32[] %constant),
  //       direction=NE
  // }
  //
  // ENTRY %SubcomputationAccounting () -> f32[2,4] {
  //   %constant.3 = f32[2,4]{1,0} constant(
  //       f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } })
  //   %transpose = f32[2,4]{1,0} transpose(f32[2,4]{1,0} %constant.3),
  //       dimensions={0,1}
  //   %constant.2 = f32[4]{0} constant({1, 1, 1, 1})
  //   %while = f32[4]{0} while(f32[4]{0} %constant.2),
  //       condition=%WhileCond, body=%WhileBody
  //   %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={1}
  //   ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %transpose,
  //                                 f32[2,4]{1,0} %broadcast)
  // }

  auto module = CreateNewVerifiedModule();
  const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
  const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});

  // reshape(slice(param)) != 0
  // Needs 5 bytes
  auto cond_builder = HloComputation::Builder("WhileCond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "cond_param"));
  HloInstruction* slice =
      cond_builder.AddInstruction(HloInstruction::CreateSlice(
          ShapeUtil::MakeShape(F32, {1}), cond_param, {0}, {1}, {1}));
  HloInstruction* reshape =
      cond_builder.AddInstruction(HloInstruction::CreateReshape(r0f32, slice));
  HloInstruction* zero = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  HloInstruction* cond_comparison = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), reshape,
                                    zero, ComparisonDirection::kNe));
  auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());

  // param - 1
  // Needs 16 bytes
  auto body_builder = HloComputation::Builder("WhileBody");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "body_param"));
  HloInstruction* one_vector =
      body_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 1, 1, 1})));
  HloInstruction* subtract =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          r1f32, HloOpcode::kSubtract, body_param, one_vector));
  auto body_computation = module->AddEmbeddedComputation(body_builder.Build());

  // transpose(matrix) + bcast(while)
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* while_init =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 1, 1, 1})));
  // Creates 16 bytes, ignoring subcomputations
  HloInstruction* while_loop =
      builder.AddInstruction(HloInstruction::CreateWhile(
          r1f32, cond_computation, body_computation, while_init));

  // Creates 32 bytes and frees 16
  HloInstruction* bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(r2f32, while_loop, {1}));

  HloInstruction* matrix = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>(
          {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}})));
  // Creates 32 bytes
  HloInstruction* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(r2f32, matrix, {0, 1}));

  // Creates 32 bytes and frees 64
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast));

  auto entry_computation = module->AddEntryComputation(builder.Build());

  // Schedule each computation in the order its instructions were created.
  HloSchedule schedule(module.get());
  std::vector<HloInstruction*> cond_vec = {cond_param, slice, reshape, zero,
                                           cond_comparison};
  std::vector<HloInstruction*> while_body_vec = {body_param, one_vector,
                                                 subtract};
  std::vector<HloInstruction*> entry_comp_vec = {while_init, while_loop, bcast,
                                                 matrix,     transpose,  add};
  schedule.set_sequence(cond_computation, cond_vec);
  schedule.set_sequence(body_computation, while_body_vec);
  schedule.set_sequence(entry_computation, entry_comp_vec);

  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  // Memory figures for the subcomputations are supplied by hand (matching the
  // "Needs N bytes" notes above) rather than computed by the simulator.
  absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
  memory_by_computation[cond_computation] = 5;
  memory_by_computation[body_computation] = 16;
  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
      TuplePointsToAnalysis::Run(module.get()).ValueOrDie();

  // HeapSimulator accounts for subcomputations. The output buffer is aliased,
  // so we don't double count.
  EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation(
                    *entry_computation, schedule.sequence(entry_computation),
                    *points_to_analysis, size_fn, &memory_by_computation)
                    .ValueOrDie());
}
    219 
    220 const char kAlloc[] = "Alloc";
    221 const char kFree[] = "Free";
    222 const char kFinish[] = "Finish";
    223 
// CallSequence records a sequence of Alloc/Free/Finish calls.
// Each entry pairs the call's label (kAlloc, kFree or kFinish) with the buffer
// the call was made for; Finish entries are recorded with a null buffer.
using CallSequence = std::vector<std::pair<string, const BufferValue*>>;
    226 
    227 // HeapCallRecorder is a dummy heap algorithm that simply records its calls.
    228 class HeapCallRecorder : public HeapAlgorithm {
    229  public:
    230   explicit HeapCallRecorder(CallSequence* calls) : calls_(calls) {}
    231   ~HeapCallRecorder() override {}
    232 
    233   void Alloc(const BufferValue* buffer, int64 size) override {
    234     calls_->emplace_back(kAlloc, buffer);
    235     // Instead of assigning a real offset, we set the cardinality of the Alloc
    236     // call.  This isn't a valid assignment, but allows us to easily test for
    237     // buffer sharing.
    238     const int64 offset = result_.chunk_map.size();
    239     result_.chunk_map.emplace(buffer, Chunk{offset, size});
    240   }
    241   void Free(const BufferValue* buffer, int64 size) override {
    242     calls_->emplace_back(kFree, buffer);
    243   }
    244   Result Finish() override {
    245     calls_->emplace_back(kFinish, nullptr);
    246     return result_;
    247   }
    248 
    249  private:
    250   CallSequence* calls_;
    251   Result result_;
    252 };
    253 
    254 // HeapSimulatorTracker runs the heap simulator, recording the sequence of calls
    255 // made to the underlying heap algorithm.  Tests compare the actual call
    256 // sequence against an expected sequence.
    257 class HeapSimulatorTracker {
    258  public:
    259   // Constructor for testing a single entry computation.
    260   HeapSimulatorTracker(
    261       const string& name, std::unique_ptr<HloComputation> computation,
    262       const std::vector<HloInstruction*>& instruction_sequence) {
    263     HloModuleConfig config;
    264     module_ = absl::make_unique<HloModule>(name, config);
    265     module_->AddEntryComputation(std::move(computation));
    266     points_to_analysis_ =
    267         TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
    268     // Since we're only tracking the sequence of Alloc/Free calls, the actual
    269     // size of the buffers doesn't matter, so we always return 0.  We rely on
    270     // the secondary sorting criteria of DecreasingSizeRunsHeap to sort calls by
    271     // buffer id, for determinism in the tests.
    272     auto zero_size = [](const BufferValue& buffer) { return 0; };
    273     auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>(
    274         absl::make_unique<HeapCallRecorder>(&actual_calls_));
    275     result_ =
    276         HeapSimulator::Run(std::move(algorithm), *module_->entry_computation(),
    277                            HloInstructionSequence(instruction_sequence),
    278                            *points_to_analysis_, zero_size)
    279             .ConsumeValueOrDie();
    280   }
    281 
    282   explicit HeapSimulatorTracker(const string& name) {
    283     HloModuleConfig config;
    284     module_ = absl::make_unique<HloModule>(name, config);
    285   }
    286 
    287   // Similar to the single entry computation constructor above, but runs the
    288   // simulation over the entire module.
    289   void RunWholeModule(
    290       const std::vector<HloInstruction*>& full_module_sequence) {
    291     points_to_analysis_ =
    292         TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
    293 
    294     // Construct the module sequence grouped by computation.
    295     HloSchedule schedule(module_.get());
    296     absl::flat_hash_map<const HloInstruction*, int> reverse_position;
    297     for (int i = 0; i < full_module_sequence.size(); ++i) {
    298       HloInstruction* instruction = full_module_sequence[i];
    299       schedule.GetOrCreateSequence(instruction->parent())
    300           .push_back(instruction);
    301       reverse_position[instruction] = full_module_sequence.size() - i;
    302     }
    303 
    304     // Hack the size_fn so that it returns a decreasing value as we step through
    305     // the sequence. This lets us ensure the Alloc calls are in the sequence
    306     // order. The Free calls are sorted by BufferValue.id, which is at least
    307     // deterministic.
    308     auto size_fn = [&reverse_position](const BufferValue& buffer) {
    309       return reverse_position[buffer.instruction()];
    310     };
    311     auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>(
    312         absl::make_unique<HeapCallRecorder>(&actual_calls_));
    313     result_ = HeapSimulator::Run(std::move(algorithm), *module_, schedule,
    314                                  *points_to_analysis_, size_fn)
    315                   .ConsumeValueOrDie();
    316   }
    317 
    318   HloModule* module() { return module_.get(); }
    319 
    320   // Returns the buffer defined at the given instruction and index.
    321   const BufferValue* BufferAt(const HloInstruction* instruction,
    322                               const ShapeIndex& index) const {
    323     return points_to_analysis_->GetBufferDefinedAt(instruction, index)
    324         .ConsumeValueOrDie();
    325   }
    326 
    327   int64 OffsetAt(const HloInstruction* instruction, const ShapeIndex& index) {
    328     const BufferValue* buffer = BufferAt(instruction, index);
    329     return result_.chunk_map.at(buffer).offset;
    330   }
    331 
    332   // Ensures the expected sequence of Alloc/Free/Finish calls was performed.
    333   void ExpectCallSequence(const CallSequence& expected) const {
    334     EXPECT_EQ(expected, actual_calls_);
    335   }
    336 
    337   // Ensures the buffers defined by the respective (instruction,index) pairs are
    338   // shared, relying on the unique offsets assigned in HeapCallRecorder::Alloc.
    339   void ExpectSharedBuffers(const HloInstruction* instruction_a,
    340                            const ShapeIndex& index_a,
    341                            const HloInstruction* instruction_b,
    342                            const ShapeIndex& index_b) {
    343     int64 offset_a = OffsetAt(instruction_a, index_a);
    344     int64 offset_b = OffsetAt(instruction_b, index_b);
    345     EXPECT_EQ(offset_a, offset_b);
    346   }
    347 
    348  private:
    349   std::unique_ptr<HloModule> module_;
    350   std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
    351   CallSequence actual_calls_;
    352   HeapSimulator::Result result_;
    353 };
    354 
    355 class HeapSimulatorTest : public HloTestBase {
    356  protected:
    357   HeapSimulatorTest() {}
    358   ~HeapSimulatorTest() override {}
    359 
    360   // Shapes for use in the examples.
    361   Shape f32scalar_ = ShapeUtil::MakeShape(xla::F32, {});
    362   Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4});
    363 };
    364 
    365 TEST_F(HeapSimulatorTest, ScalarConstant) {
    366   auto builder = HloComputation::Builder(TestName());
    367   auto const0 = builder.AddInstruction(
    368       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
    369 
    370   // Constants aren't assigned.  See b/32248867
    371   HeapSimulatorTracker tracker(TestName(), builder.Build(), {const0});
    372   tracker.ExpectCallSequence({{kFinish, nullptr}});
    373 }
    374 
    375 TEST_F(HeapSimulatorTest, OneParam) {
    376   auto builder = HloComputation::Builder(TestName());
    377   auto param0 = builder.AddInstruction(
    378       HloInstruction::CreateParameter(0, f32scalar_, "param0"));
    379 
    380   // A single parameter which is also the output.
    381   HeapSimulatorTracker tracker(TestName(), builder.Build(), {param0});
    382   tracker.ExpectCallSequence({
    383       {kAlloc, tracker.BufferAt(param0, {})},
    384       {kFree, tracker.BufferAt(param0, {})},
    385       {kFinish, nullptr},
    386   });
    387 }
    388 
    389 TEST_F(HeapSimulatorTest, Multiply) {
    390   auto builder = HloComputation::Builder(TestName());
    391   auto paramA = builder.AddInstruction(
    392       HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
    393   auto paramX = builder.AddInstruction(
    394       HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
    395   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
    396       f32vec4_, HloOpcode::kMultiply, paramA, paramX));
    397 
    398   // We must keep all parameters and outputs.
    399   HeapSimulatorTracker tracker(TestName(), builder.Build(),
    400                                {paramA, paramX, mul});
    401   tracker.ExpectCallSequence({
    402       {kAlloc, tracker.BufferAt(paramA, {})},
    403       {kAlloc, tracker.BufferAt(paramX, {})},
    404       {kAlloc, tracker.BufferAt(mul, {})},
    405       // All params and outputs are freed at the end.
    406       {kFree, tracker.BufferAt(paramA, {})},
    407       {kFree, tracker.BufferAt(paramX, {})},
    408       {kFree, tracker.BufferAt(mul, {})},
    409       {kFinish, nullptr},
    410   });
    411 }
    412 
    413 TEST_F(HeapSimulatorTest, MultiplyAdd) {
    414   auto builder = HloComputation::Builder(TestName());
    415   auto paramA = builder.AddInstruction(
    416       HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
    417   auto paramX = builder.AddInstruction(
    418       HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
    419   auto paramY = builder.AddInstruction(
    420       HloInstruction::CreateParameter(2, f32vec4_, "paramY"));
    421   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
    422       f32vec4_, HloOpcode::kMultiply, paramA, paramX));
    423   auto add = builder.AddInstruction(
    424       HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, paramY));
    425 
    426   // The buffer for add is the output, and it's shared with the buffer for mul.
    427   HeapSimulatorTracker tracker(TestName(), builder.Build(),
    428                                {paramA, paramX, mul, paramY, add});
    429   tracker.ExpectCallSequence({
    430       {kAlloc, tracker.BufferAt(paramA, {})},
    431       {kAlloc, tracker.BufferAt(paramX, {})},
    432       {kAlloc, tracker.BufferAt(mul, {})},
    433       {kAlloc, tracker.BufferAt(paramY, {})},
    434       // All params and outputs are freed at the end.
    435       {kFree, tracker.BufferAt(paramA, {})},
    436       {kFree, tracker.BufferAt(paramX, {})},
    437       {kFree, tracker.BufferAt(mul, {})},
    438       {kFree, tracker.BufferAt(paramY, {})},
    439       {kFinish, nullptr},
    440   });
    441   tracker.ExpectSharedBuffers(add, {}, mul, {});
    442 }
    443 
    444 TEST_F(HeapSimulatorTest, BufferReusedOnce) {
    445   HeapSimulatorTracker tracker(TestName());
    446   auto builder = HloComputation::Builder(TestName());
    447 
    448   HloComputation::Builder fusion_builder("fusion");
    449   {
    450     HloComputation::Builder& builder = fusion_builder;
    451     auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
    452         /*parameter_number=*/0, f32vec4_, "A"));
    453     auto exp = builder.AddInstruction(
    454         HloInstruction::CreateUnary(f32vec4_, HloOpcode::kExp, a_param));
    455     auto neg = builder.AddInstruction(
    456         HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param));
    457 
    458     builder.AddInstruction(HloInstruction::CreateTuple({exp, neg}));
    459   }
    460   auto fusion_computation =
    461       tracker.module()->AddEmbeddedComputation(fusion_builder.Build());
    462   auto a_param = builder.AddInstruction(
    463       HloInstruction::CreateParameter(0, f32vec4_, "paramA"));
    464   auto neg = builder.AddInstruction(
    465       HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param));
    466   auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
    467       ShapeUtil::MakeTupleShape({f32vec4_, f32vec4_}),
    468       HloInstruction::FusionKind::kLoop, {neg}, fusion_computation));
    469   tracker.module()->AddEntryComputation(builder.Build());
    470 
    471   tracker.RunWholeModule({a_param, neg, fusion});
    472 
    473   auto neg_buffer = tracker.OffsetAt(neg, {});
    474   int64 output_buffer_0 = tracker.OffsetAt(fusion, {0});
    475   int64 output_buffer_1 = tracker.OffsetAt(fusion, {1});
    476   // Only one buffer should be shared.
    477   EXPECT_TRUE((neg_buffer == output_buffer_0) ^
    478               (neg_buffer == output_buffer_1));
    479 }
    480 
    481 TEST_F(HeapSimulatorTest, MultiplyDot) {
    482   auto builder = HloComputation::Builder(TestName());
    483   auto paramA = builder.AddInstruction(
    484       HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
    485   auto paramX = builder.AddInstruction(
    486       HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
    487   auto paramY = builder.AddInstruction(
    488       HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
    489   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
    490       f32vec4_, HloOpcode::kMultiply, paramA, paramX));
    491   DotDimensionNumbers dot_dnums;
    492   dot_dnums.add_lhs_contracting_dimensions(1);
    493   dot_dnums.add_rhs_contracting_dimensions(0);
    494   auto dot = builder.AddInstruction(HloInstruction::CreateDot(
    495       f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
    496 
    497   // The buffer for dot is the output, and it cannot be shared with the buffer
    498   // for mul, since dot isn't elementwise.
    499   HeapSimulatorTracker tracker(TestName(), builder.Build(),
    500                                {paramA, paramX, mul, paramY, dot});
    501   tracker.ExpectCallSequence({
    502       {kAlloc, tracker.BufferAt(paramA, {})},
    503       {kAlloc, tracker.BufferAt(paramX, {})},
    504       {kAlloc, tracker.BufferAt(mul, {})},
    505       {kAlloc, tracker.BufferAt(paramY, {})},
    506       {kAlloc, tracker.BufferAt(dot, {})},
    507       // All params and outputs are freed at the end.
    508       {kFree, tracker.BufferAt(paramA, {})},
    509       {kFree, tracker.BufferAt(paramX, {})},
    510       {kFree, tracker.BufferAt(mul, {})},
    511       {kFree, tracker.BufferAt(paramY, {})},
    512       {kFree, tracker.BufferAt(dot, {})},
    513       {kFinish, nullptr},
    514   });
    515 }
    516 
    517 TEST_F(HeapSimulatorTest, MultiplyDotAdd) {
    518   auto builder = HloComputation::Builder(TestName());
    519   auto paramA = builder.AddInstruction(
    520       HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
    521   auto paramX = builder.AddInstruction(
    522       HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
    523   auto paramY = builder.AddInstruction(
    524       HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
    525   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
    526       f32vec4_, HloOpcode::kMultiply, paramA, paramX));
    527   DotDimensionNumbers dot_dnums;
    528   dot_dnums.add_lhs_contracting_dimensions(1);
    529   dot_dnums.add_rhs_contracting_dimensions(0);
    530   auto dot = builder.AddInstruction(HloInstruction::CreateDot(
    531       f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
    532   auto add = builder.AddInstruction(
    533       HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA));
    534 
    535   // The buffer for add is the output, and it's shared with the buffer for dot.
    536   HeapSimulatorTracker tracker(TestName(), builder.Build(),
    537                                {paramA, paramX, mul, paramY, dot, add});
    538   tracker.ExpectCallSequence({
    539       {kAlloc, tracker.BufferAt(paramA, {})},
    540       {kAlloc, tracker.BufferAt(paramX, {})},
    541       {kAlloc, tracker.BufferAt(mul, {})},
    542       {kAlloc, tracker.BufferAt(paramY, {})},
    543       {kAlloc, tracker.BufferAt(dot, {})},
    544       // All params and outputs are freed at the end.
    545       {kFree, tracker.BufferAt(paramA, {})},
    546       {kFree, tracker.BufferAt(paramX, {})},
    547       {kFree, tracker.BufferAt(mul, {})},
    548       {kFree, tracker.BufferAt(paramY, {})},
    549       {kFree, tracker.BufferAt(dot, {})},
    550       {kFinish, nullptr},
    551   });
    552   tracker.ExpectSharedBuffers(add, {}, dot, {});
    553 }
    554 
    555 TEST_F(HeapSimulatorTest, MultiplyDotDot) {
    556   auto builder = HloComputation::Builder(TestName());
    557   auto paramA = builder.AddInstruction(
    558       HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
    559   auto paramX = builder.AddInstruction(
    560       HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
    561   auto paramY = builder.AddInstruction(
    562       HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
    563   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
    564       f32vec4_, HloOpcode::kMultiply, paramA, paramX));
    565   DotDimensionNumbers dot_dnums;
    566   dot_dnums.add_lhs_contracting_dimensions(1);
    567   dot_dnums.add_rhs_contracting_dimensions(0);
    568   auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
    569       f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
    570   auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
    571       f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));
    572 
    573   // The buffer for dot1 is the output.  No buffers can be shared.  The buffer
    574   // for mul is freed before the end, since it's no longer used after dot0
    575   // finishes.
    576   HeapSimulatorTracker tracker(TestName(), builder.Build(),
    577                                {paramA, paramX, mul, paramY, dot0, dot1});
    578   tracker.ExpectCallSequence({
    579       {kAlloc, tracker.BufferAt(paramA, {})},
    580       {kAlloc, tracker.BufferAt(paramX, {})},
    581       {kAlloc, tracker.BufferAt(mul, {})},
    582       {kAlloc, tracker.BufferAt(paramY, {})},
    583       {kAlloc, tracker.BufferAt(dot0, {})},
    584       {kFree, tracker.BufferAt(mul, {})},  // mul no longer used
    585       {kAlloc, tracker.BufferAt(dot1, {})},
    586       // All params and outputs are freed at the end.
    587       {kFree, tracker.BufferAt(paramA, {})},
    588       {kFree, tracker.BufferAt(paramX, {})},
    589       {kFree, tracker.BufferAt(paramY, {})},
    590       {kFree, tracker.BufferAt(dot0, {})},
    591       {kFree, tracker.BufferAt(dot1, {})},
    592       {kFinish, nullptr},
    593   });
    594 }
    595 
    596 TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) {
    597   auto builder = HloComputation::Builder(TestName());
    598   auto paramA = builder.AddInstruction(
    599       HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
    600   auto paramX = builder.AddInstruction(
    601       HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
    602   auto paramY = builder.AddInstruction(
    603       HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
    604   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
    605       f32vec4_, HloOpcode::kMultiply, paramA, paramX));
    606   DotDimensionNumbers dot_dnums;
    607   dot_dnums.add_lhs_contracting_dimensions(1);
    608   dot_dnums.add_rhs_contracting_dimensions(0);
    609   auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
    610       f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
    611   auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
    612       f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));
    613   auto tuple =
    614       builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1}));
    615 
    616   // The buffers for dot0, dot1 and tuple are the output.  No buffers can be
    617   // shared.  The buffer for mul is freed before the end, since it's no longer
    618   // used after dot0 finishes.
    619   HeapSimulatorTracker tracker(
    620       TestName(), builder.Build(),
    621       {paramA, paramX, mul, paramY, dot0, dot1, tuple});
    622   tracker.ExpectCallSequence({
    623       {kAlloc, tracker.BufferAt(paramA, {})},
    624       {kAlloc, tracker.BufferAt(paramX, {})},
    625       {kAlloc, tracker.BufferAt(mul, {})},
    626       {kAlloc, tracker.BufferAt(paramY, {})},
    627       {kAlloc, tracker.BufferAt(dot0, {})},
    628       {kFree, tracker.BufferAt(mul, {})},  // mul no longer used
    629       {kAlloc, tracker.BufferAt(dot1, {})},
    630       {kAlloc, tracker.BufferAt(tuple, {})},
    631       // All params and outputs are freed at the end.
    632       {kFree, tracker.BufferAt(paramA, {})},
    633       {kFree, tracker.BufferAt(paramX, {})},
    634       {kFree, tracker.BufferAt(paramY, {})},
    635       {kFree, tracker.BufferAt(dot0, {})},
    636       {kFree, tracker.BufferAt(dot1, {})},
    637       {kFree, tracker.BufferAt(tuple, {})},
    638       {kFinish, nullptr},
    639   });
    640 }
    641 
    642 TEST_F(HeapSimulatorTest, IndependentTupleElements) {
    643   auto builder = HloComputation::Builder(TestName());
    644   auto paramA = builder.AddInstruction(
    645       HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
    646   auto paramB = builder.AddInstruction(
    647       HloInstruction::CreateParameter(1, f32scalar_, "paramB"));
    648   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
    649       f32scalar_, HloOpcode::kMultiply, paramA, paramB));
    650   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
    651       f32scalar_, HloOpcode::kAdd, paramA, paramB));
    652   auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({mul, add}));
    653   auto element0 = builder.AddInstruction(
    654       HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 0));
    655   auto broadcast = builder.AddInstruction(
    656       HloInstruction::CreateBroadcast(f32vec4_, element0, {0}));
    657   auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
    658       f32scalar_, HloOpcode::kSubtract, paramA, paramB));
    659   auto element1 = builder.AddInstruction(
    660       HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 1));
    661   auto output = builder.AddInstruction(
    662       HloInstruction::CreateTuple({broadcast, sub, element1}));
    663 
    664   HeapSimulatorTracker tracker(TestName(), builder.Build(),
    665                                {paramA, paramB, mul, add, tuple, element0,
    666                                 broadcast, sub, element1, output});
    667   tracker.ExpectCallSequence({
    668       {kAlloc, tracker.BufferAt(paramA, {})},
    669       {kAlloc, tracker.BufferAt(paramB, {})},
    670       {kAlloc, tracker.BufferAt(mul, {})},
    671       {kAlloc, tracker.BufferAt(add, {})},
    672       {kAlloc, tracker.BufferAt(tuple, {})},
    673       {kAlloc, tracker.BufferAt(broadcast, {})},
    674       // The mul can be freed right after the broadcast happens, even though
    675       // The other GetTupleElement is still alive.
    676       {kFree, tracker.BufferAt(mul, {})},
    677       {kAlloc, tracker.BufferAt(sub, {})},
    678       // The temporary tuple is now dead.
    679       {kFree, tracker.BufferAt(tuple, {})},
    680       {kAlloc, tracker.BufferAt(output, {})},
    681       // All params and outputs are freed at the end.
    682       {kFree, tracker.BufferAt(paramA, {})},
    683       {kFree, tracker.BufferAt(paramB, {})},
    684       {kFree, tracker.BufferAt(add, {})},
    685       {kFree, tracker.BufferAt(broadcast, {})},
    686       {kFree, tracker.BufferAt(sub, {})},
    687       {kFree, tracker.BufferAt(output, {})},
    688       {kFinish, nullptr},
    689   });
    690 }
    691 
    692 TEST_F(HeapSimulatorTest, WholeModule) {
    693   HeapSimulatorTracker tracker(TestName());
    694 
    695   const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
    696   const Shape tuple_shape =
    697       ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});
    698 
    699   auto cond_builder = HloComputation::Builder("WhileCond");
    700   HloInstruction* cond_param = cond_builder.AddInstruction(
    701       HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
    702   HloInstruction* cond_iter = cond_builder.AddInstruction(
    703       HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
    704   HloInstruction* cond_data = cond_builder.AddInstruction(
    705       HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
    706   HloInstruction* cond_lt = cond_builder.AddInstruction(
    707       HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
    708                                     cond_data, ComparisonDirection::kLt));
    709   HloComputation* cond_computation =
    710       tracker.module()->AddEmbeddedComputation(cond_builder.Build());
    711 
    712   auto body_builder = HloComputation::Builder("WhileBody");
    713   HloInstruction* body_param = body_builder.AddInstruction(
    714       HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
    715   HloComputation* body_computation =
    716       tracker.module()->AddEmbeddedComputation(body_builder.Build());
    717 
    718   auto builder = HloComputation::Builder(TestName());
    719   HloInstruction* param = builder.AddInstruction(
    720       HloInstruction::CreateParameter(0, tuple_shape, "param"));
    721   HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
    722       tuple_shape, cond_computation, body_computation, param));
    723   tracker.module()->AddEntryComputation(builder.Build());
    724 
    725   tracker.RunWholeModule(
    726       {param, while_op, body_param, cond_param, cond_iter, cond_data, cond_lt});
    727   tracker.ExpectCallSequence({
    728       // The entry computation param and while_op are allocated first.
    729       {kAlloc, tracker.BufferAt(param, {})},
    730       {kAlloc, tracker.BufferAt(param, {0})},
    731       {kAlloc, tracker.BufferAt(param, {1})},
    732       {kAlloc, tracker.BufferAt(while_op, {})},
    733       {kAlloc, tracker.BufferAt(while_op, {0})},
    734       {kAlloc, tracker.BufferAt(while_op, {1})},
    735 
    736       // Now the while body param is allocated and freed.
    737       {kAlloc, tracker.BufferAt(body_param, {})},
    738       {kAlloc, tracker.BufferAt(body_param, {0})},
    739       {kAlloc, tracker.BufferAt(body_param, {1})},
    740       {kFree, tracker.BufferAt(body_param, {})},
    741       {kFree, tracker.BufferAt(body_param, {0})},
    742       {kFree, tracker.BufferAt(body_param, {1})},
    743 
    744       // Now the while cond param is allocated. The GTE instructions just alias
    745       // the param elements, so the param tuple can immediately be freed.
    746       {kAlloc, tracker.BufferAt(cond_param, {})},
    747       {kAlloc, tracker.BufferAt(cond_param, {0})},
    748       {kAlloc, tracker.BufferAt(cond_param, {1})},
    749       {kFree, tracker.BufferAt(cond_param, {})},
    750 
    751       // Now the final cond less-than buffer is allocated.
    752       {kAlloc, tracker.BufferAt(cond_lt, {})},
    753 
    754       // The order of the remaining Free calls is based on the BufferValue.id,
    755       // which is deterministic, but not obvious.
    756       {kFree, tracker.BufferAt(param, {})},
    757       {kFree, tracker.BufferAt(param, {0})},
    758       {kFree, tracker.BufferAt(param, {1})},
    759 
    760       {kFree, tracker.BufferAt(while_op, {})},
    761       {kFree, tracker.BufferAt(while_op, {0})},
    762       {kFree, tracker.BufferAt(while_op, {1})},
    763 
    764       {kFree, tracker.BufferAt(cond_param, {0})},
    765       {kFree, tracker.BufferAt(cond_param, {1})},
    766       {kFree, tracker.BufferAt(cond_lt, {})},
    767 
    768       {kFinish, nullptr},
    769   });
    770 }
    771 
    772 // Base class for heap algorithm tests.
    773 class HeapAlgorithmTestBase : public ::testing::Test {
    774  protected:
    775   HeapAlgorithmTestBase() : builder_("heap_simulator_test") {
    776     buffer_a_ = DummyBufferValue();
    777     buffer_b_ = DummyBufferValue();
    778     buffer_c_ = DummyBufferValue();
    779     buffer_d_ = DummyBufferValue();
    780     buffer_e_ = DummyBufferValue();
    781     buffer_f_ = DummyBufferValue();
    782     buffer_g_ = DummyBufferValue();
    783     buffer_h_ = DummyBufferValue();
    784     buffer_i_ = DummyBufferValue();
    785   }
    786   ~HeapAlgorithmTestBase() override {}
    787 
    788   const BufferValue* buffer_a_;
    789   const BufferValue* buffer_b_;
    790   const BufferValue* buffer_c_;
    791   const BufferValue* buffer_d_;
    792   const BufferValue* buffer_e_;
    793   const BufferValue* buffer_f_;
    794   const BufferValue* buffer_g_;
    795   const BufferValue* buffer_h_;
    796   const BufferValue* buffer_i_;
    797 
    798  private:
    799   // Create a dummy BufferValue to pass to the heap algorithm.
    800   const BufferValue* DummyBufferValue() {
    801     const BufferValue::Id id = buffers_.size();
    802     auto const0 = builder_.AddInstruction(
    803         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
    804     buffers_.emplace_back(
    805         absl::make_unique<HloValue>(id, const0, ShapeIndex{}));
    806     return buffers_.back().get();
    807   }
    808 
    809   HloComputation::Builder builder_;
    810   std::vector<std::unique_ptr<BufferValue>> buffers_;
    811 };
    812 
    813 class NoFragmentationStatsHeapTest : public HeapAlgorithmTestBase {};
    814 
    815 TEST_F(NoFragmentationStatsHeapTest, Empty) {
    816   NoFragmentationStatsHeap heap;
    817   EXPECT_EQ(0, heap.Finish().heap_size);
    818 }
    819 
    820 TEST_F(NoFragmentationStatsHeapTest, Simple) {
    821   NoFragmentationStatsHeap heap;
    822   heap.Alloc(buffer_a_, 10);
    823   heap.Alloc(buffer_b_, 20);
    824   heap.Alloc(buffer_c_, 30);
    825   heap.Alloc(buffer_d_, 30);
    826   heap.Free(buffer_a_, 10);
    827   heap.Free(buffer_b_, 20);
    828   heap.Free(buffer_c_, 30);
    829   heap.Free(buffer_d_, 30);
    830   EXPECT_EQ(90, heap.Finish().heap_size);
    831 }
    832 
    833 TEST_F(NoFragmentationStatsHeapTest, Mixed) {
    834   NoFragmentationStatsHeap heap;
    835   heap.Alloc(buffer_a_, 10);  // max: A
    836 
    837   heap.Alloc(buffer_b_, 20);  // max: A+B
    838   heap.Free(buffer_b_, 20);
    839 
    840   heap.Alloc(buffer_c_, 30);  // max: A+C
    841   heap.Free(buffer_c_, 30);
    842 
    843   heap.Alloc(buffer_d_, 5);  // max: A+C
    844   heap.Free(buffer_d_, 5);
    845 
    846   heap.Free(buffer_a_, 10);
    847   EXPECT_EQ(40, heap.Finish().heap_size);
    848 }
    849 
    850 class DecreasingSizeRunsHeapTest : public HeapAlgorithmTestBase {};
    851 
    852 TEST_F(DecreasingSizeRunsHeapTest, Empty) {
    853   CallSequence call_sequence;
    854   DecreasingSizeRunsHeap heap(
    855       absl::make_unique<HeapCallRecorder>(&call_sequence));
    856   heap.Finish();
    857   EXPECT_EQ(call_sequence, CallSequence({
    858                                {kFinish, nullptr},
    859                            }));
    860 }
    861 
    862 TEST_F(DecreasingSizeRunsHeapTest, Simple) {
    863   CallSequence call_sequence;
    864   DecreasingSizeRunsHeap heap(
    865       absl::make_unique<HeapCallRecorder>(&call_sequence));
    866   heap.Alloc(buffer_a_, 10);
    867   heap.Alloc(buffer_b_, 20);
    868   heap.Alloc(buffer_c_, 30);
    869   heap.Alloc(buffer_d_, 30);
    870   heap.Free(buffer_a_, 10);
    871   heap.Free(buffer_b_, 20);
    872   heap.Free(buffer_c_, 30);
    873   heap.Free(buffer_d_, 30);
    874   heap.Finish();
    875   // Runs of Allocs and Frees are sorted by decreasing size, with buffer id
    876   // tiebreaker.
    877   EXPECT_EQ(call_sequence, CallSequence({
    878                                {kAlloc, buffer_c_},
    879                                {kAlloc, buffer_d_},
    880                                {kAlloc, buffer_b_},
    881                                {kAlloc, buffer_a_},
    882                                {kFree, buffer_c_},
    883                                {kFree, buffer_d_},
    884                                {kFree, buffer_b_},
    885                                {kFree, buffer_a_},
    886                                {kFinish, nullptr},
    887                            }));
    888 }
    889 
    890 TEST_F(DecreasingSizeRunsHeapTest, Mixed) {
    891   CallSequence call_sequence;
    892   DecreasingSizeRunsHeap heap(
    893       absl::make_unique<HeapCallRecorder>(&call_sequence));
    894   heap.Alloc(buffer_a_, 10);
    895   heap.Alloc(buffer_b_, 20);
    896   heap.Free(buffer_b_, 20);
    897 
    898   heap.Alloc(buffer_c_, 30);
    899   heap.Free(buffer_c_, 30);
    900 
    901   heap.Alloc(buffer_d_, 5);
    902   heap.Free(buffer_d_, 5);
    903   heap.Free(buffer_a_, 10);
    904   heap.Finish();
    905   // Runs of Allocs and Frees are sorted by decreasing size.
    906   EXPECT_EQ(call_sequence, CallSequence({
    907                                {kAlloc, buffer_b_},
    908                                {kAlloc, buffer_a_},
    909                                {kFree, buffer_b_},
    910 
    911                                {kAlloc, buffer_c_},
    912                                {kFree, buffer_c_},
    913 
    914                                {kAlloc, buffer_d_},
    915                                {kFree, buffer_a_},
    916                                {kFree, buffer_d_},
    917                                {kFinish, nullptr},
    918                            }));
    919 }
    920 
    921 class LazyBestFitHeapTest : public HeapAlgorithmTestBase {};
    922 
    923 TEST_F(LazyBestFitHeapTest, Empty) {
    924   LazyBestFitHeap heap(/*alignment=*/1);
    925   const HeapSimulator::Result result = heap.Finish();
    926   EXPECT_EQ(0, result.heap_size);
    927   EXPECT_EQ(0, result.chunk_map.size());
    928 }
    929 
    930 TEST_F(LazyBestFitHeapTest, Simple) {
    931   LazyBestFitHeap heap(/*alignment=*/1);
    932   heap.Alloc(buffer_a_, 10);
    933   heap.Alloc(buffer_b_, 20);
    934   heap.Alloc(buffer_c_, 30);
    935   heap.Alloc(buffer_d_, 30);
    936   heap.Free(buffer_a_, 10);
    937   heap.Free(buffer_b_, 20);
    938   heap.Free(buffer_c_, 30);
    939   heap.Free(buffer_d_, 30);
    940 
    941   const HeapSimulator::Result result = heap.Finish();
    942   EXPECT_EQ(90, result.heap_size);
    943   EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
    944   EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
    945   EXPECT_EQ(30, result.chunk_map.at(buffer_c_).size);
    946   EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size);
    947 
    948   EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
    949   EXPECT_EQ(10, result.chunk_map.at(buffer_b_).offset);
    950   EXPECT_EQ(30, result.chunk_map.at(buffer_c_).offset);
    951   EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset);
    952 }
    953 
    954 TEST_F(LazyBestFitHeapTest, Mixed) {
    955   LazyBestFitHeap heap(/*alignment=*/1);
    956   heap.Alloc(buffer_a_, 10);  // A lazy offset
    957 
    958   heap.Alloc(buffer_b_, 20);  // B lazy offset
    959   heap.Free(buffer_b_, 20);   // B range = [0, 20)  free = [0, 20)
    960 
    961   heap.Alloc(buffer_c_, 30);  // C range = [0, 30)
    962   heap.Free(buffer_c_, 30);   //                    free = [0, 30)
    963 
    964   heap.Alloc(buffer_d_, 5);  // D range = [0, 5)   free = [5, 30)
    965   heap.Free(buffer_d_, 5);   //                    free = [0, 30)
    966 
    967   heap.Free(buffer_a_, 10);  // A range = [30, 10) free = [0, 40)
    968 
    969   const HeapSimulator::Result result = heap.Finish();
    970   EXPECT_EQ(40, result.heap_size);
    971   EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
    972   EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
    973   EXPECT_EQ(30, result.chunk_map.at(buffer_c_).size);
    974   EXPECT_EQ(5, result.chunk_map.at(buffer_d_).size);
    975 
    976   EXPECT_EQ(30, result.chunk_map.at(buffer_a_).offset);
    977   EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
    978   EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
    979   EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
    980 }
    981 
    982 TEST_F(LazyBestFitHeapTest, BestFit) {
    983   LazyBestFitHeap heap(/*alignment=*/1);
    984 
    985   // First alloc/free buffer_a_, to force a big free chunk to appear.
    986   heap.Alloc(buffer_a_, 200);  // A lazy offset
    987   heap.Free(buffer_a_, 200);   // A range = [0, 200)   free = [0, 200)
    988 
    989   // Now alloc a bunch of buffers that are allocated out of the free chunk.
    990   heap.Alloc(buffer_b_, 30);  // B range = [0, 30)    free = [30, 200)
    991   heap.Alloc(buffer_c_, 30);  // C range = [30, 60)   free = [60, 200)
    992   heap.Alloc(buffer_d_, 20);  // D range = [60, 80)   free = [80, 200)
    993   heap.Alloc(buffer_e_, 20);  // E range = [80, 100)  free = [100, 200)
    994   heap.Alloc(buffer_f_, 10);  // F range = [100, 110) free = [110, 200)
    995   heap.Alloc(buffer_g_, 10);  // G range = [110, 120) free = [120, 200)
    996   heap.Alloc(buffer_h_, 80);  // H range = [120, 200)
    997 
    998   // Free buffers to create free chunks of different sizes.
    999   heap.Free(buffer_c_, 30);  // free = [30, 60)
   1000   heap.Free(buffer_e_, 20);  // free = [30, 60), [80, 100)
   1001   heap.Free(buffer_g_, 10);  // free = [30, 60), [80, 100), [110, 120)
   1002 
   1003   // The best fit is picked out of the existing free chunks.
   1004   heap.Alloc(buffer_i_, 15);  // I range = [80, 95)
   1005 
   1006   // The frees here ensure the buffer-coalescing logic is exercised.
   1007   heap.Free(buffer_b_, 30);
   1008   heap.Free(buffer_d_, 20);
   1009   heap.Free(buffer_f_, 10);
   1010   heap.Free(buffer_h_, 80);
   1011   heap.Free(buffer_i_, 15);
   1012 
   1013   const HeapSimulator::Result result = heap.Finish();
   1014   EXPECT_EQ(200, result.heap_size);
   1015   EXPECT_EQ(200, result.chunk_map.at(buffer_a_).size);
   1016   EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
   1017   EXPECT_EQ(30, result.chunk_map.at(buffer_c_).size);
   1018   EXPECT_EQ(20, result.chunk_map.at(buffer_d_).size);
   1019   EXPECT_EQ(20, result.chunk_map.at(buffer_e_).size);
   1020   EXPECT_EQ(10, result.chunk_map.at(buffer_f_).size);
   1021   EXPECT_EQ(10, result.chunk_map.at(buffer_g_).size);
   1022   EXPECT_EQ(80, result.chunk_map.at(buffer_h_).size);
   1023   EXPECT_EQ(15, result.chunk_map.at(buffer_i_).size);
   1024 
   1025   EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
   1026   EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
   1027   EXPECT_EQ(30, result.chunk_map.at(buffer_c_).offset);
   1028   EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset);
   1029   EXPECT_EQ(80, result.chunk_map.at(buffer_e_).offset);
   1030   EXPECT_EQ(100, result.chunk_map.at(buffer_f_).offset);
   1031   EXPECT_EQ(110, result.chunk_map.at(buffer_g_).offset);
   1032   EXPECT_EQ(120, result.chunk_map.at(buffer_h_).offset);
   1033   EXPECT_EQ(80, result.chunk_map.at(buffer_i_).offset);
   1034 }
   1035 
   1036 TEST_F(LazyBestFitHeapTest, Lazy) {
   1037   LazyBestFitHeap heap(/*alignment=*/1);
   1038 
   1039   // First alloc some buffers, which are all lazily allocated offsets.
   1040   heap.Alloc(buffer_a_, 10);
   1041   heap.Alloc(buffer_b_, 5);
   1042   heap.Alloc(buffer_c_, 10);
   1043 
   1044   // Now free some buffers, which forces offset assignment.
   1045   heap.Free(buffer_a_, 10);  // A range = [0, 10)  free = [0, 10)
   1046   heap.Free(buffer_c_, 10);  // C range = [10, 20) free = [0, 20)
   1047 
   1048   // If we hadn't lazily assigned offsets, the free chunk wouldn't be large
   1049   // enough to hold the entire allocation.
   1050   heap.Alloc(buffer_d_, 20);  // D range = [0, 20)
   1051 
   1052   heap.Free(buffer_b_, 5);  // B range = [20, 25)
   1053   heap.Free(buffer_d_, 20);
   1054 
   1055   const HeapSimulator::Result result = heap.Finish();
   1056   EXPECT_EQ(25, result.heap_size);
   1057   EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
   1058   EXPECT_EQ(5, result.chunk_map.at(buffer_b_).size);
   1059   EXPECT_EQ(10, result.chunk_map.at(buffer_c_).size);
   1060   EXPECT_EQ(20, result.chunk_map.at(buffer_d_).size);
   1061 
   1062   EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
   1063   EXPECT_EQ(20, result.chunk_map.at(buffer_b_).offset);
   1064   EXPECT_EQ(10, result.chunk_map.at(buffer_c_).offset);
   1065   EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
   1066 }
   1067 
   1068 TEST_F(LazyBestFitHeapTest, ReuseLastFreeChunk) {
   1069   LazyBestFitHeap heap(/*alignment=*/1);
   1070 
   1071   // First alloc/free buffer_a_, to force a big free chunk to appear.
   1072   heap.Alloc(buffer_a_, 60);  // A lazy offset
   1073   heap.Free(buffer_a_, 60);   // A range = [0, 60)   free = [0, 60)
   1074 
   1075   // Now alloc a bunch of buffers that are allocated out of the free chunk.
   1076   heap.Alloc(buffer_b_, 10);  // B range = [0, 10)    free = [10, 60)
   1077   heap.Alloc(buffer_c_, 20);  // C range = [10, 30)   free = [30, 60)
   1078   heap.Alloc(buffer_d_, 30);  // D range = [30, 60)
   1079 
   1080   // Free buffers to create free chunks of different sizes.
   1081   heap.Free(buffer_b_, 10);  // free = [0, 10)
   1082   heap.Free(buffer_d_, 30);  // free = [0, 10), [30, 60)
   1083 
   1084   // No free chunks are large enough, but the last free chunk is adjacent to the
   1085   // end of the heap, so we re-use that chunk.
   1086   heap.Alloc(buffer_e_, 40);  // E range = [30, 70)
   1087 
   1088   heap.Free(buffer_c_, 20);
   1089   heap.Free(buffer_e_, 40);
   1090 
   1091   const HeapSimulator::Result result = heap.Finish();
   1092   EXPECT_EQ(70, result.heap_size);
   1093   EXPECT_EQ(60, result.chunk_map.at(buffer_a_).size);
   1094   EXPECT_EQ(10, result.chunk_map.at(buffer_b_).size);
   1095   EXPECT_EQ(20, result.chunk_map.at(buffer_c_).size);
   1096   EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size);
   1097   EXPECT_EQ(40, result.chunk_map.at(buffer_e_).size);
   1098 
   1099   EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
   1100   EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
   1101   EXPECT_EQ(10, result.chunk_map.at(buffer_c_).offset);
   1102   EXPECT_EQ(30, result.chunk_map.at(buffer_d_).offset);
   1103   EXPECT_EQ(30, result.chunk_map.at(buffer_e_).offset);
   1104 }
   1105 
   1106 TEST_F(LazyBestFitHeapTest, Alignment) {
   1107   LazyBestFitHeap heap(/*alignment=*/64);
   1108 
   1109   // First alloc some buffers, which are all lazily allocated offsets.
   1110   heap.Alloc(buffer_a_, 10);
   1111   heap.Alloc(buffer_b_, 5);
   1112   heap.Alloc(buffer_c_, 10);
   1113 
   1114   // Now free some buffers, which forces offset assignment with alignment.
   1115   heap.Free(buffer_a_, 10);  //  A range = [0, 10)    free = [0, 10)
   1116   heap.Free(buffer_c_, 10);  //  C range = [64, 74)   free = [0, 74)
   1117 
   1118   // If we hadn't lazily assigned offsets, and accounted for alignment, the free
   1119   // chunk wouldn't be large enough to hold the entire allocation.
   1120   heap.Alloc(buffer_d_, 74);  // D range = [0, 74)    free = [)
   1121 
   1122   heap.Free(buffer_b_, 5);    // B range = [128, 133) free = [74, 133)
   1123   heap.Alloc(buffer_e_, 23);  // E range = [128, 151) free = [74, 128)
   1124 
   1125   heap.Free(buffer_d_, 74);  //                       free = [0, 128)
   1126   heap.Free(buffer_e_, 23);  //                       free = [0, 151)
   1127 
   1128   const HeapSimulator::Result result = heap.Finish();
   1129   EXPECT_EQ(151, result.heap_size);
   1130   EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
   1131   EXPECT_EQ(5, result.chunk_map.at(buffer_b_).size);
   1132   EXPECT_EQ(10, result.chunk_map.at(buffer_c_).size);
   1133   EXPECT_EQ(74, result.chunk_map.at(buffer_d_).size);
   1134   EXPECT_EQ(23, result.chunk_map.at(buffer_e_).size);
   1135 
   1136   EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
   1137   EXPECT_EQ(128, result.chunk_map.at(buffer_b_).offset);
   1138   EXPECT_EQ(64, result.chunk_map.at(buffer_c_).offset);
   1139   EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
   1140   EXPECT_EQ(128, result.chunk_map.at(buffer_e_).offset);
   1141 }
   1142 
   1143 class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase {};
   1144 
   1145 TEST_F(GlobalDecreasingSizeBestFitHeapTest, Empty) {
   1146   GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
   1147   const HeapSimulator::Result result = heap.Finish();
   1148   EXPECT_EQ(0, result.heap_size);
   1149   EXPECT_EQ(0, result.chunk_map.size());
   1150 }
   1151 
   1152 TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) {
   1153   // space
   1154   //   ^
   1155   //   |  +---a---+
   1156   //   |      +-------+
   1157   //   |      +---c---+
   1158   //   |    +-------+
   1159   //   |    |   b   |
   1160   //   |    +-------+
   1161   //   |         +-------+
   1162   //   |         |       |
   1163   //   |         |   d   |
   1164   //   |         +-------+
   1165   //   -----------------> time
   1166   GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
   1167   heap.Alloc(buffer_a_, 10);
   1168   heap.Alloc(buffer_b_, 30);
   1169   heap.Alloc(buffer_c_, 20);
   1170   heap.Alloc(buffer_d_, 40);
   1171   heap.Free(buffer_a_, 10);
   1172   heap.Free(buffer_b_, 30);
   1173   heap.Free(buffer_c_, 20);
   1174   heap.Free(buffer_d_, 40);
   1175 
   1176   const HeapSimulator::Result result = heap.Finish();
   1177   EXPECT_EQ(100, result.heap_size);
   1178   EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
   1179   EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
   1180   EXPECT_EQ(20, result.chunk_map.at(buffer_c_).size);
   1181   EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size);
   1182 
   1183   EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset);
   1184   EXPECT_EQ(40, result.chunk_map.at(buffer_b_).offset);
   1185   EXPECT_EQ(70, result.chunk_map.at(buffer_c_).offset);
   1186   EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
   1187 }
   1188 
   1189 TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) {
   1190   // space
   1191   //   ^
   1192   //   |      +-------+
   1193   //   |      +---b---+
   1194   //   |            +-------+
   1195   //   |            |       |
   1196   //   |            |   d   |
   1197   //   |  +---a---+ +-------+
   1198   //   |
   1199   //   |         +-------+
   1200   //   |         |       |
   1201   //   |         |   c   |
   1202   //   |         |       |
   1203   //   |         +-------+
   1204   //   ---------------------> time
   1205   GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/20);
   1206   heap.Alloc(buffer_a_, 10);
   1207   heap.Alloc(buffer_b_, 20);
   1208   heap.Alloc(buffer_c_, 50);
   1209   heap.Free(buffer_a_, 10);
   1210   heap.Alloc(buffer_d_, 40);
   1211   heap.Free(buffer_b_, 20);
   1212   heap.Free(buffer_c_, 50);
   1213   heap.Free(buffer_d_, 40);
   1214 
   1215   const HeapSimulator::Result result = heap.Finish();
   1216   EXPECT_EQ(120, result.heap_size);
   1217   EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
   1218   EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
   1219   EXPECT_EQ(50, result.chunk_map.at(buffer_c_).size);
   1220   EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size);
   1221 
   1222   EXPECT_EQ(60, result.chunk_map.at(buffer_a_).offset);
   1223   EXPECT_EQ(100, result.chunk_map.at(buffer_b_).offset);
   1224   EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
   1225   EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset);
   1226 }
   1227 
   1228 TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) {
   1229   // space
   1230   //   ^
   1231   //   |    +-------+
   1232   //   |    +---b---+
   1233   //   |         +-------+
   1234   //   |         |   d   |
   1235   //   | +--a--+ +-------+
   1236   //   |      +-------+
   1237   //   |      |       |
   1238   //   |      |   c   |
   1239   //   |      +-------+
   1240   //   |           +-------+
   1241   //   |           |       |
   1242   //   |           |   e   |
   1243   //   |           |       |
   1244   //   |           +-------+
   1245   //   ---------------------> time
   1246   GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
   1247   heap.Alloc(buffer_a_, 10);
   1248   heap.Alloc(buffer_b_, 20);
   1249   heap.Alloc(buffer_c_, 40);
   1250   heap.Free(buffer_a_, 10);
   1251   heap.Alloc(buffer_d_, 30);
   1252   heap.Alloc(buffer_e_, 50);
   1253   heap.Free(buffer_b_, 20);
   1254   heap.Free(buffer_c_, 40);
   1255   heap.Free(buffer_d_, 30);
   1256   heap.Free(buffer_e_, 50);
   1257 
   1258   const HeapSimulator::Result result = heap.Finish();
   1259   EXPECT_EQ(140, result.heap_size);
   1260   EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
   1261   EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
   1262   EXPECT_EQ(40, result.chunk_map.at(buffer_c_).size);
   1263   EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size);
   1264   EXPECT_EQ(50, result.chunk_map.at(buffer_e_).size);
   1265 
   1266   EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset);
   1267   EXPECT_EQ(120, result.chunk_map.at(buffer_b_).offset);
   1268   EXPECT_EQ(50, result.chunk_map.at(buffer_c_).offset);
   1269   EXPECT_EQ(90, result.chunk_map.at(buffer_d_).offset);
   1270   EXPECT_EQ(0, result.chunk_map.at(buffer_e_).offset);
   1271 }
   1272 
   1273 }  // namespace
   1274 }  // namespace xla
   1275