Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/heap_simulator.h"
     17 
     18 #include <memory>
     19 #include <utility>
     20 #include <vector>
     21 
     22 #include "tensorflow/compiler/xla/literal_util.h"
     23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     25 #include "tensorflow/compiler/xla/service/hlo_module.h"
     26 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
     27 #include "tensorflow/compiler/xla/service/logical_buffer.h"
     28 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
     29 #include "tensorflow/compiler/xla/status_macros.h"
     30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     31 #include "tensorflow/core/lib/gtl/flatmap.h"
     32 
     33 namespace xla {
     34 namespace {
     35 
     36 const char kAlloc[] = "Alloc";
     37 const char kFree[] = "Free";
     38 const char kFinish[] = "Finish";
     39 
     40 // CallSequence records a sequence of Alloc/Free/Finish calls.
     41 using CallSequence = std::vector<std::pair<string, const LogicalBuffer*>>;
     42 
     43 // HeapCallRecorder is a dummy heap algorithm that simply records its calls.
     44 class HeapCallRecorder : public HeapAlgorithm {
     45  public:
     46   explicit HeapCallRecorder(CallSequence* calls) : calls_(calls) {}
     47   ~HeapCallRecorder() override {}
     48 
     49   void Alloc(const LogicalBuffer* buffer, int64 size) override {
     50     calls_->emplace_back(kAlloc, buffer);
     51     // Instead of assigning a real offset, we set the cardinality of the Alloc
     52     // call.  This isn't a valid assignment, but allows us to easily test for
     53     // buffer sharing.
     54     const int64 offset = result_.chunk_map.size();
     55     result_.chunk_map.emplace(buffer, Chunk{offset, size});
     56   }
     57   void Free(const LogicalBuffer* buffer, int64 size) override {
     58     calls_->emplace_back(kFree, buffer);
     59   }
     60   Result Finish() override {
     61     calls_->emplace_back(kFinish, nullptr);
     62     return result_;
     63   }
     64 
     65  private:
     66   CallSequence* calls_;
     67   Result result_;
     68 };
     69 
     70 // HeapSimulatorTracker runs the heap simulator, recording the sequence of calls
     71 // made to the underlying heap algorithm.  Tests compare the actual call
     72 // sequence against an expected sequence.
     73 class HeapSimulatorTracker {
     74  public:
     75   // Constructor for testing a single entry computation.
     76   HeapSimulatorTracker(
     77       const string& name, std::unique_ptr<HloComputation> computation,
     78       const std::vector<const HloInstruction*>& instruction_sequence) {
     79     module_ = MakeUnique<HloModule>(name);
     80     module_->AddEntryComputation(std::move(computation));
     81     points_to_analysis_ =
     82         TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
     83     // Since we're only tracking the sequence of Alloc/Free calls, the actual
     84     // size of the buffers doesn't matter, so we always return 0.  We rely on
     85     // the secondary sorting criteria of DecreasingSizeRunsHeap to sort calls by
     86     // buffer id, for determinism in the tests.
     87     auto zero_size = [](const LogicalBuffer& buffer) { return 0; };
     88     auto algorithm = MakeUnique<DecreasingSizeRunsHeap>(
     89         MakeUnique<HeapCallRecorder>(&actual_calls_));
     90     result_ = HeapSimulator::Run(
     91                   std::move(algorithm), *module_->entry_computation(),
     92                   instruction_sequence, *points_to_analysis_, zero_size)
     93                   .ConsumeValueOrDie();
     94   }
     95 
     96   explicit HeapSimulatorTracker(const string& name) {
     97     module_ = MakeUnique<HloModule>(name);
     98   }
     99 
    100   // Similar to the single entry computation constructor above, but runs the
    101   // simulation over the entire module.
    102   void RunWholeModule(
    103       const std::vector<const HloInstruction*>& full_module_sequence) {
    104     points_to_analysis_ =
    105         TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
    106 
    107     // Construct the module sequence grouped by computation.
    108     SequentialHloOrdering::HloModuleSequence module_sequence;
    109     tensorflow::gtl::FlatMap<const HloInstruction*, int> reverse_position;
    110     for (int i = 0; i < full_module_sequence.size(); ++i) {
    111       const HloInstruction* instruction = full_module_sequence[i];
    112       module_sequence[instruction->parent()].push_back(instruction);
    113       reverse_position[instruction] = full_module_sequence.size() - i;
    114     }
    115 
    116     // Hack the size_fn so that it returns a decreasing value as we step through
    117     // the sequence. This lets us ensure the Alloc calls are in the sequence
    118     // order. The Free calls are sorted by LogicalBuffer.id, which is at least
    119     // deterministic.
    120     auto size_fn = [&reverse_position](const LogicalBuffer& buffer) {
    121       return reverse_position[buffer.instruction()];
    122     };
    123     auto algorithm = MakeUnique<DecreasingSizeRunsHeap>(
    124         MakeUnique<HeapCallRecorder>(&actual_calls_));
    125     result_ = HeapSimulator::Run(std::move(algorithm), *module_,
    126                                  module_sequence, *points_to_analysis_, size_fn)
    127                   .ConsumeValueOrDie();
    128   }
    129 
    130   HloModule* module() { return module_.get(); }
    131 
    132   // Returns the buffer defined at the given instruction and index.
    133   const LogicalBuffer* BufferAt(const HloInstruction* instruction,
    134                                 const ShapeIndex& index) const {
    135     return points_to_analysis_->GetBufferDefinedAt(instruction, index)
    136         .ConsumeValueOrDie();
    137   }
    138 
    139   // Ensures the expected sequence of Alloc/Free/Finish calls was performed.
    140   void ExpectCallSequence(const CallSequence& expected) const {
    141     EXPECT_EQ(expected, actual_calls_);
    142   }
    143 
    144   // Ensures the buffers defined by the respective (instruction,index) pairs are
    145   // shared, relying on the unique offsets assigned in HeapCallRecorder::Alloc.
    146   void ExpectSharedBuffers(const HloInstruction* instruction_a,
    147                            const ShapeIndex& index_a,
    148                            const HloInstruction* instruction_b,
    149                            const ShapeIndex& index_b) {
    150     const LogicalBuffer* a = BufferAt(instruction_a, index_a);
    151     const LogicalBuffer* b = BufferAt(instruction_b, index_b);
    152     EXPECT_EQ(result_.chunk_map[a].offset, result_.chunk_map[b].offset)
    153         << *a << ", " << *b;
    154   }
    155 
    156  private:
    157   std::unique_ptr<HloModule> module_;
    158   std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
    159   CallSequence actual_calls_;
    160   HeapSimulator::Result result_;
    161 };
    162 
    163 class HeapSimulatorTest : public HloTestBase {
    164  protected:
    165   HeapSimulatorTest() {}
    166   ~HeapSimulatorTest() override {}
    167 
    168   // Shapes for use in the examples.
    169   Shape f32scalar_ = ShapeUtil::MakeShape(xla::F32, {});
    170   Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4});
    171 };
    172 
    173 TEST_F(HeapSimulatorTest, ScalarConstant) {
    174   auto builder = HloComputation::Builder(TestName());
    175   auto const0 = builder.AddInstruction(
    176       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    177 
    178   // Constants aren't assigned.  See b/32248867
    179   HeapSimulatorTracker tracker(TestName(), builder.Build(), {const0});
    180   tracker.ExpectCallSequence({{kFinish, nullptr}});
    181 }
    182 
    183 TEST_F(HeapSimulatorTest, OneParam) {
    184   auto builder = HloComputation::Builder(TestName());
    185   auto param0 = builder.AddInstruction(
    186       HloInstruction::CreateParameter(0, f32scalar_, "param0"));
    187 
    188   // A single parameter which is also the output.
    189   HeapSimulatorTracker tracker(TestName(), builder.Build(), {param0});
    190   tracker.ExpectCallSequence({
    191       {kAlloc, tracker.BufferAt(param0, {})},
    192       {kFree, tracker.BufferAt(param0, {})},
    193       {kFinish, nullptr},
    194   });
    195 }
    196 
    197 TEST_F(HeapSimulatorTest, Multiply) {
    198   auto builder = HloComputation::Builder(TestName());
    199   auto paramA = builder.AddInstruction(
    200       HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
    201   auto paramX = builder.AddInstruction(
    202       HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
    203   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
    204       f32vec4_, HloOpcode::kMultiply, paramA, paramX));
    205 
    206   // We must keep all parameters and outputs.
    207   HeapSimulatorTracker tracker(TestName(), builder.Build(),
    208                                {paramA, paramX, mul});
    209   tracker.ExpectCallSequence({
    210       {kAlloc, tracker.BufferAt(paramA, {})},
    211       {kAlloc, tracker.BufferAt(paramX, {})},
    212       {kAlloc, tracker.BufferAt(mul, {})},
    213       // All params and outputs are freed at the end.
    214       {kFree, tracker.BufferAt(paramA, {})},
    215       {kFree, tracker.BufferAt(paramX, {})},
    216       {kFree, tracker.BufferAt(mul, {})},
    217       {kFinish, nullptr},
    218   });
    219 }
    220 
    221 TEST_F(HeapSimulatorTest, MultiplyAdd) {
    222   auto builder = HloComputation::Builder(TestName());
    223   auto paramA = builder.AddInstruction(
    224       HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
    225   auto paramX = builder.AddInstruction(
    226       HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
    227   auto paramY = builder.AddInstruction(
    228       HloInstruction::CreateParameter(2, f32vec4_, "paramY"));
    229   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
    230       f32vec4_, HloOpcode::kMultiply, paramA, paramX));
    231   auto add = builder.AddInstruction(
    232       HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, paramY));
    233 
    234   // The buffer for add is the output, and it's shared with the buffer for mul.
    235   HeapSimulatorTracker tracker(TestName(), builder.Build(),
    236                                {paramA, paramX, mul, paramY, add});
    237   tracker.ExpectCallSequence({
    238       {kAlloc, tracker.BufferAt(paramA, {})},
    239       {kAlloc, tracker.BufferAt(paramX, {})},
    240       {kAlloc, tracker.BufferAt(mul, {})},
    241       {kAlloc, tracker.BufferAt(paramY, {})},
    242       // All params and outputs are freed at the end.
    243       {kFree, tracker.BufferAt(paramA, {})},
    244       {kFree, tracker.BufferAt(paramX, {})},
    245       {kFree, tracker.BufferAt(mul, {})},
    246       {kFree, tracker.BufferAt(paramY, {})},
    247       {kFinish, nullptr},
    248   });
    249   tracker.ExpectSharedBuffers(add, {}, mul, {});
    250 }
    251 
    252 TEST_F(HeapSimulatorTest, MultiplyDot) {
    253   auto builder = HloComputation::Builder(TestName());
    254   auto paramA = builder.AddInstruction(
    255       HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
    256   auto paramX = builder.AddInstruction(
    257       HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
    258   auto paramY = builder.AddInstruction(
    259       HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
    260   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
    261       f32vec4_, HloOpcode::kMultiply, paramA, paramX));
    262   DotDimensionNumbers dot_dnums;
    263   dot_dnums.add_lhs_contracting_dimensions(1);
    264   dot_dnums.add_rhs_contracting_dimensions(0);
    265   auto dot = builder.AddInstruction(
    266       HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
    267 
    268   // The buffer for dot is the output, and it cannot be shared with the buffer
    269   // for mul, since dot isn't elementwise.
    270   HeapSimulatorTracker tracker(TestName(), builder.Build(),
    271                                {paramA, paramX, mul, paramY, dot});
    272   tracker.ExpectCallSequence({
    273       {kAlloc, tracker.BufferAt(paramA, {})},
    274       {kAlloc, tracker.BufferAt(paramX, {})},
    275       {kAlloc, tracker.BufferAt(mul, {})},
    276       {kAlloc, tracker.BufferAt(paramY, {})},
    277       {kAlloc, tracker.BufferAt(dot, {})},
    278       // All params and outputs are freed at the end.
    279       {kFree, tracker.BufferAt(paramA, {})},
    280       {kFree, tracker.BufferAt(paramX, {})},
    281       {kFree, tracker.BufferAt(mul, {})},
    282       {kFree, tracker.BufferAt(paramY, {})},
    283       {kFree, tracker.BufferAt(dot, {})},
    284       {kFinish, nullptr},
    285   });
    286 }
    287 
    288 TEST_F(HeapSimulatorTest, MultiplyDotAdd) {
    289   auto builder = HloComputation::Builder(TestName());
    290   auto paramA = builder.AddInstruction(
    291       HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
    292   auto paramX = builder.AddInstruction(
    293       HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
    294   auto paramY = builder.AddInstruction(
    295       HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
    296   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
    297       f32vec4_, HloOpcode::kMultiply, paramA, paramX));
    298   DotDimensionNumbers dot_dnums;
    299   dot_dnums.add_lhs_contracting_dimensions(1);
    300   dot_dnums.add_rhs_contracting_dimensions(0);
    301   auto dot = builder.AddInstruction(
    302       HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
    303   auto add = builder.AddInstruction(
    304       HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA));
    305 
    306   // The buffer for add is the output, and it's shared with the buffer for dot.
    307   HeapSimulatorTracker tracker(TestName(), builder.Build(),
    308                                {paramA, paramX, mul, paramY, dot, add});
    309   tracker.ExpectCallSequence({
    310       {kAlloc, tracker.BufferAt(paramA, {})},
    311       {kAlloc, tracker.BufferAt(paramX, {})},
    312       {kAlloc, tracker.BufferAt(mul, {})},
    313       {kAlloc, tracker.BufferAt(paramY, {})},
    314       {kAlloc, tracker.BufferAt(dot, {})},
    315       // All params and outputs are freed at the end.
    316       {kFree, tracker.BufferAt(paramA, {})},
    317       {kFree, tracker.BufferAt(paramX, {})},
    318       {kFree, tracker.BufferAt(mul, {})},
    319       {kFree, tracker.BufferAt(paramY, {})},
    320       {kFree, tracker.BufferAt(dot, {})},
    321       {kFinish, nullptr},
    322   });
    323   tracker.ExpectSharedBuffers(add, {}, dot, {});
    324 }
    325 
    326 TEST_F(HeapSimulatorTest, MultiplyDotDot) {
    327   auto builder = HloComputation::Builder(TestName());
    328   auto paramA = builder.AddInstruction(
    329       HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
    330   auto paramX = builder.AddInstruction(
    331       HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
    332   auto paramY = builder.AddInstruction(
    333       HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
    334   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
    335       f32vec4_, HloOpcode::kMultiply, paramA, paramX));
    336   DotDimensionNumbers dot_dnums;
    337   dot_dnums.add_lhs_contracting_dimensions(1);
    338   dot_dnums.add_rhs_contracting_dimensions(0);
    339   auto dot0 = builder.AddInstruction(
    340       HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
    341   auto dot1 = builder.AddInstruction(
    342       HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums));
    343 
    344   // The buffer for dot1 is the output.  No buffers can be shared.  The buffer
    345   // for mul is freed before the end, since it's no longer used after dot0
    346   // finishes.
    347   HeapSimulatorTracker tracker(TestName(), builder.Build(),
    348                                {paramA, paramX, mul, paramY, dot0, dot1});
    349   tracker.ExpectCallSequence({
    350       {kAlloc, tracker.BufferAt(paramA, {})},
    351       {kAlloc, tracker.BufferAt(paramX, {})},
    352       {kAlloc, tracker.BufferAt(mul, {})},
    353       {kAlloc, tracker.BufferAt(paramY, {})},
    354       {kAlloc, tracker.BufferAt(dot0, {})},
    355       {kFree, tracker.BufferAt(mul, {})},  // mul no longer used
    356       {kAlloc, tracker.BufferAt(dot1, {})},
    357       // All params and outputs are freed at the end.
    358       {kFree, tracker.BufferAt(paramA, {})},
    359       {kFree, tracker.BufferAt(paramX, {})},
    360       {kFree, tracker.BufferAt(paramY, {})},
    361       {kFree, tracker.BufferAt(dot0, {})},
    362       {kFree, tracker.BufferAt(dot1, {})},
    363       {kFinish, nullptr},
    364   });
    365 }
    366 
    367 TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) {
    368   auto builder = HloComputation::Builder(TestName());
    369   auto paramA = builder.AddInstruction(
    370       HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
    371   auto paramX = builder.AddInstruction(
    372       HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
    373   auto paramY = builder.AddInstruction(
    374       HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
    375   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
    376       f32vec4_, HloOpcode::kMultiply, paramA, paramX));
    377   DotDimensionNumbers dot_dnums;
    378   dot_dnums.add_lhs_contracting_dimensions(1);
    379   dot_dnums.add_rhs_contracting_dimensions(0);
    380   auto dot0 = builder.AddInstruction(
    381       HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
    382   auto dot1 = builder.AddInstruction(
    383       HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums));
    384   auto tuple =
    385       builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1}));
    386 
    387   // The buffers for dot0, dot1 and tuple are the output.  No buffers can be
    388   // shared.  The buffer for mul is freed before the end, since it's no longer
    389   // used after dot0 finishes.
    390   HeapSimulatorTracker tracker(
    391       TestName(), builder.Build(),
    392       {paramA, paramX, mul, paramY, dot0, dot1, tuple});
    393   tracker.ExpectCallSequence({
    394       {kAlloc, tracker.BufferAt(paramA, {})},
    395       {kAlloc, tracker.BufferAt(paramX, {})},
    396       {kAlloc, tracker.BufferAt(mul, {})},
    397       {kAlloc, tracker.BufferAt(paramY, {})},
    398       {kAlloc, tracker.BufferAt(dot0, {})},
    399       {kFree, tracker.BufferAt(mul, {})},  // mul no longer used
    400       {kAlloc, tracker.BufferAt(dot1, {})},
    401       {kAlloc, tracker.BufferAt(tuple, {})},
    402       // All params and outputs are freed at the end.
    403       {kFree, tracker.BufferAt(paramA, {})},
    404       {kFree, tracker.BufferAt(paramX, {})},
    405       {kFree, tracker.BufferAt(paramY, {})},
    406       {kFree, tracker.BufferAt(dot0, {})},
    407       {kFree, tracker.BufferAt(dot1, {})},
    408       {kFree, tracker.BufferAt(tuple, {})},
    409       {kFinish, nullptr},
    410   });
    411 }
    412 
    413 TEST_F(HeapSimulatorTest, WholeModule) {
    414   HeapSimulatorTracker tracker(TestName());
    415 
    416   const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
    417   const Shape tuple_shape =
    418       ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});
    419 
    420   auto cond_builder = HloComputation::Builder("WhileCond");
    421   HloInstruction* cond_param = cond_builder.AddInstruction(
    422       HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
    423   HloInstruction* cond_iter = cond_builder.AddInstruction(
    424       HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
    425   HloInstruction* cond_data = cond_builder.AddInstruction(
    426       HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
    427   HloInstruction* cond_lt = cond_builder.AddInstruction(
    428       HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
    429                                    HloOpcode::kLt, cond_iter, cond_data));
    430   HloComputation* cond_computation =
    431       tracker.module()->AddEmbeddedComputation(cond_builder.Build());
    432 
    433   auto body_builder = HloComputation::Builder("WhileBody");
    434   HloInstruction* body_param = body_builder.AddInstruction(
    435       HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
    436   HloComputation* body_computation =
    437       tracker.module()->AddEmbeddedComputation(body_builder.Build());
    438 
    439   auto builder = HloComputation::Builder(TestName());
    440   HloInstruction* param = builder.AddInstruction(
    441       HloInstruction::CreateParameter(0, tuple_shape, "param"));
    442   HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
    443       tuple_shape, cond_computation, body_computation, param));
    444   tracker.module()->AddEntryComputation(builder.Build());
    445 
    446   tracker.RunWholeModule(
    447       {param, while_op, body_param, cond_param, cond_iter, cond_data, cond_lt});
    448   tracker.ExpectCallSequence({
    449       // The entry computation param and while_op are allocated first.
    450       {kAlloc, tracker.BufferAt(param, {})},
    451       {kAlloc, tracker.BufferAt(param, {0})},
    452       {kAlloc, tracker.BufferAt(param, {1})},
    453       {kAlloc, tracker.BufferAt(while_op, {})},
    454       {kAlloc, tracker.BufferAt(while_op, {0})},
    455       {kAlloc, tracker.BufferAt(while_op, {1})},
    456 
    457       // Now the while body param is allocated and freed.
    458       {kAlloc, tracker.BufferAt(body_param, {})},
    459       {kAlloc, tracker.BufferAt(body_param, {0})},
    460       {kAlloc, tracker.BufferAt(body_param, {1})},
    461       {kFree, tracker.BufferAt(body_param, {})},
    462       {kFree, tracker.BufferAt(body_param, {0})},
    463       {kFree, tracker.BufferAt(body_param, {1})},
    464 
    465       // Now the while cond param is allocated. The GTE instructions just alias
    466       // the param elements, so the param tuple can immediately be freed.
    467       {kAlloc, tracker.BufferAt(cond_param, {})},
    468       {kAlloc, tracker.BufferAt(cond_param, {0})},
    469       {kAlloc, tracker.BufferAt(cond_param, {1})},
    470       {kFree, tracker.BufferAt(cond_param, {})},
    471 
    472       // Now the final cond less-than buffer is allocated.
    473       {kAlloc, tracker.BufferAt(cond_lt, {})},
    474 
    475       // The order of the remaining Free calls is based on the LogicalBuffer.id,
    476       // which is deterministic, but not obvious.
    477       {kFree, tracker.BufferAt(param, {})},
    478       {kFree, tracker.BufferAt(param, {0})},
    479       {kFree, tracker.BufferAt(param, {1})},
    480 
    481       {kFree, tracker.BufferAt(while_op, {})},
    482       {kFree, tracker.BufferAt(while_op, {0})},
    483       {kFree, tracker.BufferAt(while_op, {1})},
    484 
    485       {kFree, tracker.BufferAt(cond_param, {0})},
    486       {kFree, tracker.BufferAt(cond_param, {1})},
    487       {kFree, tracker.BufferAt(cond_lt, {})},
    488 
    489       {kFinish, nullptr},
    490   });
    491 }
    492 
    493 // Base class for heap algorithm tests.
    494 class HeapAlgorithmTestBase : public ::testing::Test {
    495  protected:
    496   HeapAlgorithmTestBase() : builder_("heap_simulator_test") {
    497     buffer_a_ = DummyLogicalBuffer();
    498     buffer_b_ = DummyLogicalBuffer();
    499     buffer_c_ = DummyLogicalBuffer();
    500     buffer_d_ = DummyLogicalBuffer();
    501     buffer_e_ = DummyLogicalBuffer();
    502     buffer_f_ = DummyLogicalBuffer();
    503     buffer_g_ = DummyLogicalBuffer();
    504     buffer_h_ = DummyLogicalBuffer();
    505     buffer_i_ = DummyLogicalBuffer();
    506   }
    507   ~HeapAlgorithmTestBase() override {}
    508 
    509   const LogicalBuffer* buffer_a_;
    510   const LogicalBuffer* buffer_b_;
    511   const LogicalBuffer* buffer_c_;
    512   const LogicalBuffer* buffer_d_;
    513   const LogicalBuffer* buffer_e_;
    514   const LogicalBuffer* buffer_f_;
    515   const LogicalBuffer* buffer_g_;
    516   const LogicalBuffer* buffer_h_;
    517   const LogicalBuffer* buffer_i_;
    518 
    519  private:
    520   // Create a dummy LogicalBuffer to pass to the heap algorithm.
    521   const LogicalBuffer* DummyLogicalBuffer() {
    522     const LogicalBuffer::Id id = buffers_.size();
    523     auto const0 = builder_.AddInstruction(
    524         HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    525     buffers_.emplace_back(MakeUnique<LogicalBuffer>(const0, ShapeIndex{}, id));
    526     return buffers_.back().get();
    527   }
    528 
    529   HloComputation::Builder builder_;
    530   std::vector<std::unique_ptr<LogicalBuffer>> buffers_;
    531 };
    532 
    533 class NoFragmentationStatsHeapTest : public HeapAlgorithmTestBase {};
    534 
    535 TEST_F(NoFragmentationStatsHeapTest, Empty) {
    536   NoFragmentationStatsHeap heap;
    537   EXPECT_EQ(0, heap.Finish().heap_size);
    538 }
    539 
    540 TEST_F(NoFragmentationStatsHeapTest, Simple) {
    541   NoFragmentationStatsHeap heap;
    542   heap.Alloc(buffer_a_, 10);
    543   heap.Alloc(buffer_b_, 20);
    544   heap.Alloc(buffer_c_, 30);
    545   heap.Alloc(buffer_d_, 30);
    546   heap.Free(buffer_a_, 10);
    547   heap.Free(buffer_b_, 20);
    548   heap.Free(buffer_c_, 30);
    549   heap.Free(buffer_d_, 30);
    550   EXPECT_EQ(90, heap.Finish().heap_size);
    551 }
    552 
    553 TEST_F(NoFragmentationStatsHeapTest, Mixed) {
    554   NoFragmentationStatsHeap heap;
    555   heap.Alloc(buffer_a_, 10);  // max: A
    556 
    557   heap.Alloc(buffer_b_, 20);  // max: A+B
    558   heap.Free(buffer_b_, 20);
    559 
    560   heap.Alloc(buffer_c_, 30);  // max: A+C
    561   heap.Free(buffer_c_, 30);
    562 
    563   heap.Alloc(buffer_d_, 5);  // max: A+C
    564   heap.Free(buffer_d_, 5);
    565 
    566   heap.Free(buffer_a_, 10);
    567   EXPECT_EQ(40, heap.Finish().heap_size);
    568 }
    569 
    570 class DecreasingSizeRunsHeapTest : public HeapAlgorithmTestBase {};
    571 
    572 TEST_F(DecreasingSizeRunsHeapTest, Empty) {
    573   CallSequence call_sequence;
    574   DecreasingSizeRunsHeap heap(MakeUnique<HeapCallRecorder>(&call_sequence));
    575   heap.Finish();
    576   EXPECT_EQ(call_sequence, CallSequence({
    577                                {kFinish, nullptr},
    578                            }));
    579 }
    580 
    581 TEST_F(DecreasingSizeRunsHeapTest, Simple) {
    582   CallSequence call_sequence;
    583   DecreasingSizeRunsHeap heap(MakeUnique<HeapCallRecorder>(&call_sequence));
    584   heap.Alloc(buffer_a_, 10);
    585   heap.Alloc(buffer_b_, 20);
    586   heap.Alloc(buffer_c_, 30);
    587   heap.Alloc(buffer_d_, 30);
    588   heap.Free(buffer_a_, 10);
    589   heap.Free(buffer_b_, 20);
    590   heap.Free(buffer_c_, 30);
    591   heap.Free(buffer_d_, 30);
    592   heap.Finish();
    593   // Runs of Allocs and Frees are sorted by decreasing size, with buffer id
    594   // tiebreaker.
    595   EXPECT_EQ(call_sequence, CallSequence({
    596                                {kAlloc, buffer_c_},
    597                                {kAlloc, buffer_d_},
    598                                {kAlloc, buffer_b_},
    599                                {kAlloc, buffer_a_},
    600                                {kFree, buffer_c_},
    601                                {kFree, buffer_d_},
    602                                {kFree, buffer_b_},
    603                                {kFree, buffer_a_},
    604                                {kFinish, nullptr},
    605                            }));
    606 }
    607 
    608 TEST_F(DecreasingSizeRunsHeapTest, Mixed) {
    609   CallSequence call_sequence;
    610   DecreasingSizeRunsHeap heap(MakeUnique<HeapCallRecorder>(&call_sequence));
    611   heap.Alloc(buffer_a_, 10);
    612   heap.Alloc(buffer_b_, 20);
    613   heap.Free(buffer_b_, 20);
    614 
    615   heap.Alloc(buffer_c_, 30);
    616   heap.Free(buffer_c_, 30);
    617 
    618   heap.Alloc(buffer_d_, 5);
    619   heap.Free(buffer_d_, 5);
    620   heap.Free(buffer_a_, 10);
    621   heap.Finish();
    622   // Runs of Allocs and Frees are sorted by decreasing size.
    623   EXPECT_EQ(call_sequence, CallSequence({
    624                                {kAlloc, buffer_b_},
    625                                {kAlloc, buffer_a_},
    626                                {kFree, buffer_b_},
    627 
    628                                {kAlloc, buffer_c_},
    629                                {kFree, buffer_c_},
    630 
    631                                {kAlloc, buffer_d_},
    632                                {kFree, buffer_a_},
    633                                {kFree, buffer_d_},
    634                                {kFinish, nullptr},
    635                            }));
    636 }
    637 
    638 class LazyBestFitHeapTest : public HeapAlgorithmTestBase {};
    639 
    640 TEST_F(LazyBestFitHeapTest, Empty) {
    641   LazyBestFitHeap heap(/*alignment=*/1);
    642   const HeapSimulator::Result result = heap.Finish();
    643   EXPECT_EQ(0, result.heap_size);
    644   EXPECT_EQ(0, result.chunk_map.size());
    645 }
    646 
    647 TEST_F(LazyBestFitHeapTest, Simple) {
    648   LazyBestFitHeap heap(/*alignment=*/1);
    649   heap.Alloc(buffer_a_, 10);
    650   heap.Alloc(buffer_b_, 20);
    651   heap.Alloc(buffer_c_, 30);
    652   heap.Alloc(buffer_d_, 30);
    653   heap.Free(buffer_a_, 10);
    654   heap.Free(buffer_b_, 20);
    655   heap.Free(buffer_c_, 30);
    656   heap.Free(buffer_d_, 30);
    657 
    658   const HeapSimulator::Result result = heap.Finish();
    659   EXPECT_EQ(90, result.heap_size);
    660   EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
    661   EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
    662   EXPECT_EQ(30, result.chunk_map.at(buffer_c_).size);
    663   EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size);
    664 
    665   EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
    666   EXPECT_EQ(10, result.chunk_map.at(buffer_b_).offset);
    667   EXPECT_EQ(30, result.chunk_map.at(buffer_c_).offset);
    668   EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset);
    669 }
    670 
    671 TEST_F(LazyBestFitHeapTest, Mixed) {
    672   LazyBestFitHeap heap(/*alignment=*/1);
    673   heap.Alloc(buffer_a_, 10);  // A lazy offset
    674 
    675   heap.Alloc(buffer_b_, 20);  // B lazy offset
    676   heap.Free(buffer_b_, 20);   // B range = [0, 20)  free = [0, 20)
    677 
    678   heap.Alloc(buffer_c_, 30);  // C range = [0, 30)
    679   heap.Free(buffer_c_, 30);   //                    free = [0, 30)
    680 
    681   heap.Alloc(buffer_d_, 5);  // D range = [0, 5)   free = [5, 30)
    682   heap.Free(buffer_d_, 5);   //                    free = [0, 30)
    683 
    684   heap.Free(buffer_a_, 10);  // A range = [30, 10) free = [0, 40)
    685 
    686   const HeapSimulator::Result result = heap.Finish();
    687   EXPECT_EQ(40, result.heap_size);
    688   EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
    689   EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
    690   EXPECT_EQ(30, result.chunk_map.at(buffer_c_).size);
    691   EXPECT_EQ(5, result.chunk_map.at(buffer_d_).size);
    692 
    693   EXPECT_EQ(30, result.chunk_map.at(buffer_a_).offset);
    694   EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
    695   EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
    696   EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
    697 }
    698 
    699 TEST_F(LazyBestFitHeapTest, BestFit) {
    700   LazyBestFitHeap heap(/*alignment=*/1);
    701 
    702   // First alloc/free buffer_a_, to force a big free chunk to appear.
    703   heap.Alloc(buffer_a_, 200);  // A lazy offset
    704   heap.Free(buffer_a_, 200);   // A range = [0, 200)   free = [0, 200)
    705 
    706   // Now alloc a bunch of buffers that are allocated out of the free chunk.
    707   heap.Alloc(buffer_b_, 30);  // B range = [0, 30)    free = [30, 200)
    708   heap.Alloc(buffer_c_, 30);  // C range = [30, 60)   free = [60, 200)
    709   heap.Alloc(buffer_d_, 20);  // D range = [60, 80)   free = [80, 200)
    710   heap.Alloc(buffer_e_, 20);  // E range = [80, 100)  free = [100, 200)
    711   heap.Alloc(buffer_f_, 10);  // F range = [100, 110) free = [110, 200)
    712   heap.Alloc(buffer_g_, 10);  // G range = [110, 120) free = [120, 200)
    713   heap.Alloc(buffer_h_, 80);  // H range = [120, 200)
    714 
    715   // Free buffers to create free chunks of different sizes.
    716   heap.Free(buffer_c_, 30);  // free = [30, 60)
    717   heap.Free(buffer_e_, 20);  // free = [30, 60), [80, 100)
    718   heap.Free(buffer_g_, 10);  // free = [30, 60), [80, 100), [110, 120)
    719 
    720   // The best fit is picked out of the existing free chunks.
    721   heap.Alloc(buffer_i_, 15);  // I range = [80, 95)
    722 
    723   // The frees here ensure the buffer-coalescing logic is exercised.
    724   heap.Free(buffer_b_, 30);
    725   heap.Free(buffer_d_, 20);
    726   heap.Free(buffer_f_, 10);
    727   heap.Free(buffer_h_, 80);
    728   heap.Free(buffer_i_, 15);
    729 
    730   const HeapSimulator::Result result = heap.Finish();
    731   EXPECT_EQ(200, result.heap_size);
    732   EXPECT_EQ(200, result.chunk_map.at(buffer_a_).size);
    733   EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
    734   EXPECT_EQ(30, result.chunk_map.at(buffer_c_).size);
    735   EXPECT_EQ(20, result.chunk_map.at(buffer_d_).size);
    736   EXPECT_EQ(20, result.chunk_map.at(buffer_e_).size);
    737   EXPECT_EQ(10, result.chunk_map.at(buffer_f_).size);
    738   EXPECT_EQ(10, result.chunk_map.at(buffer_g_).size);
    739   EXPECT_EQ(80, result.chunk_map.at(buffer_h_).size);
    740   EXPECT_EQ(15, result.chunk_map.at(buffer_i_).size);
    741 
    742   EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
    743   EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
    744   EXPECT_EQ(30, result.chunk_map.at(buffer_c_).offset);
    745   EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset);
    746   EXPECT_EQ(80, result.chunk_map.at(buffer_e_).offset);
    747   EXPECT_EQ(100, result.chunk_map.at(buffer_f_).offset);
    748   EXPECT_EQ(110, result.chunk_map.at(buffer_g_).offset);
    749   EXPECT_EQ(120, result.chunk_map.at(buffer_h_).offset);
    750   EXPECT_EQ(80, result.chunk_map.at(buffer_i_).offset);
    751 }
    752 
    753 TEST_F(LazyBestFitHeapTest, Lazy) {
    754   LazyBestFitHeap heap(/*alignment=*/1);
    755 
    756   // First alloc some buffers, which are all lazily allocated offsets.
    757   heap.Alloc(buffer_a_, 10);
    758   heap.Alloc(buffer_b_, 5);
    759   heap.Alloc(buffer_c_, 10);
    760 
    761   // Now free some buffers, which forces offset assignment.
    762   heap.Free(buffer_a_, 10);  // A range = [0, 10)  free = [0, 10)
    763   heap.Free(buffer_c_, 10);  // C range = [10, 20) free = [0, 20)
    764 
    765   // If we hadn't lazily assigned offsets, the free chunk wouldn't be large
    766   // enough to hold the entire allocation.
    767   heap.Alloc(buffer_d_, 20);  // D range = [0, 20)
    768 
    769   heap.Free(buffer_b_, 5);  // B range = [20, 25)
    770   heap.Free(buffer_d_, 20);
    771 
    772   const HeapSimulator::Result result = heap.Finish();
    773   EXPECT_EQ(25, result.heap_size);
    774   EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
    775   EXPECT_EQ(5, result.chunk_map.at(buffer_b_).size);
    776   EXPECT_EQ(10, result.chunk_map.at(buffer_c_).size);
    777   EXPECT_EQ(20, result.chunk_map.at(buffer_d_).size);
    778 
    779   EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
    780   EXPECT_EQ(20, result.chunk_map.at(buffer_b_).offset);
    781   EXPECT_EQ(10, result.chunk_map.at(buffer_c_).offset);
    782   EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
    783 }
    784 
    785 TEST_F(LazyBestFitHeapTest, ReuseLastFreeChunk) {
    786   LazyBestFitHeap heap(/*alignment=*/1);
    787 
    788   // First alloc/free buffer_a_, to force a big free chunk to appear.
    789   heap.Alloc(buffer_a_, 60);  // A lazy offset
    790   heap.Free(buffer_a_, 60);   // A range = [0, 60)   free = [0, 60)
    791 
    792   // Now alloc a bunch of buffers that are allocated out of the free chunk.
    793   heap.Alloc(buffer_b_, 10);  // B range = [0, 10)    free = [10, 60)
    794   heap.Alloc(buffer_c_, 20);  // C range = [10, 30)   free = [30, 60)
    795   heap.Alloc(buffer_d_, 30);  // D range = [30, 60)
    796 
    797   // Free buffers to create free chunks of different sizes.
    798   heap.Free(buffer_b_, 10);  // free = [0, 10)
    799   heap.Free(buffer_d_, 30);  // free = [0, 10), [30, 60)
    800 
    801   // No free chunks are large enough, but the last free chunk is adjacent to the
    802   // end of the heap, so we re-use that chunk.
    803   heap.Alloc(buffer_e_, 40);  // E range = [30, 70)
    804 
    805   heap.Free(buffer_c_, 20);
    806   heap.Free(buffer_e_, 40);
    807 
    808   const HeapSimulator::Result result = heap.Finish();
    809   EXPECT_EQ(70, result.heap_size);
    810   EXPECT_EQ(60, result.chunk_map.at(buffer_a_).size);
    811   EXPECT_EQ(10, result.chunk_map.at(buffer_b_).size);
    812   EXPECT_EQ(20, result.chunk_map.at(buffer_c_).size);
    813   EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size);
    814   EXPECT_EQ(40, result.chunk_map.at(buffer_e_).size);
    815 
    816   EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
    817   EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
    818   EXPECT_EQ(10, result.chunk_map.at(buffer_c_).offset);
    819   EXPECT_EQ(30, result.chunk_map.at(buffer_d_).offset);
    820   EXPECT_EQ(30, result.chunk_map.at(buffer_e_).offset);
    821 }
    822 
    823 TEST_F(LazyBestFitHeapTest, Alignment) {
    824   LazyBestFitHeap heap(/*alignment=*/64);
    825 
    826   // First alloc some buffers, which are all lazily allocated offsets.
    827   heap.Alloc(buffer_a_, 10);
    828   heap.Alloc(buffer_b_, 5);
    829   heap.Alloc(buffer_c_, 10);
    830 
    831   // Now free some buffers, which forces offset assignment with alignment.
    832   heap.Free(buffer_a_, 10);  //  A range = [0, 10)    free = [0, 10)
    833   heap.Free(buffer_c_, 10);  //  C range = [64, 74)   free = [0, 74)
    834 
    835   // If we hadn't lazily assigned offsets, and accounted for alignment, the free
    836   // chunk wouldn't be large enough to hold the entire allocation.
    837   heap.Alloc(buffer_d_, 74);  // D range = [0, 74)    free = [)
    838 
    839   heap.Free(buffer_b_, 5);    // B range = [128, 133) free = [74, 133)
    840   heap.Alloc(buffer_e_, 23);  // E range = [128, 151) free = [74, 128)
    841 
    842   heap.Free(buffer_d_, 74);  //                       free = [0, 128)
    843   heap.Free(buffer_e_, 23);  //                       free = [0, 151)
    844 
    845   const HeapSimulator::Result result = heap.Finish();
    846   EXPECT_EQ(151, result.heap_size);
    847   EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
    848   EXPECT_EQ(5, result.chunk_map.at(buffer_b_).size);
    849   EXPECT_EQ(10, result.chunk_map.at(buffer_c_).size);
    850   EXPECT_EQ(74, result.chunk_map.at(buffer_d_).size);
    851   EXPECT_EQ(23, result.chunk_map.at(buffer_e_).size);
    852 
    853   EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
    854   EXPECT_EQ(128, result.chunk_map.at(buffer_b_).offset);
    855   EXPECT_EQ(64, result.chunk_map.at(buffer_c_).offset);
    856   EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
    857   EXPECT_EQ(128, result.chunk_map.at(buffer_e_).offset);
    858 }
    859 
    860 }  // namespace
    861 }  // namespace xla
    862