/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/heap_simulator.h"

#include <algorithm>
#include <vector>

#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/liveness_util.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

using tensorflow::gtl::FlatMap;
using tensorflow::gtl::FlatSet;

namespace {

// Returns the union of buffers that may be sources of any operand of the
// given instruction.  The returned buffers are guaranteed to have no
// duplicates, and to be sorted in a deterministic order.
std::vector<const LogicalBuffer*> UniqueOperandSourceBuffers(
    const HloInstruction* instruction,
    const TuplePointsToAnalysis& points_to_analysis) {
  std::vector<const LogicalBuffer*> buffers;
  for (const HloInstruction* operand : instruction->operands()) {
    points_to_analysis.GetPointsToSet(operand).ForEachElement(
        [&](const ShapeIndex& /*index*/,
            const PointsToSet::BufferList& points_to) {
          buffers.insert(buffers.end(), points_to.begin(), points_to.end());
        });
  }

  // Sort and then remove duplicates from buffers.
  std::sort(buffers.begin(), buffers.end(),
            [](const LogicalBuffer* a, const LogicalBuffer* b) {
              return a->id() < b->id();
            });
  buffers.erase(std::unique(buffers.begin(), buffers.end(),
                            [](const LogicalBuffer* a, const LogicalBuffer* b) {
                              return a->id() == b->id();
                            }),
                buffers.end());
  return buffers;
}

}  // namespace

/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
    const SequentialHloOrdering::HloModuleSequence& module_sequence,
    const TuplePointsToAnalysis& points_to_analysis,
    const LogicalBuffer::SizeFunction& size_fn, const Options& options) {
  HeapSimulator heap(std::move(algorithm), size_fn, options, &module_sequence);
  const HloComputation* entry_computation = module.entry_computation();
  const std::vector<const HloInstruction*>& instruction_sequence =
      FindOrDie(module_sequence, entry_computation);
  TF_RETURN_IF_ERROR(heap.RunComputation(
      *entry_computation, instruction_sequence, points_to_analysis));
  return heap.Finish();
}

/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
    const std::vector<const HloInstruction*>& instruction_sequence,
    const TuplePointsToAnalysis& points_to_analysis,
    const LogicalBuffer::SizeFunction& size_fn, const Options& options) {
  HeapSimulator heap(std::move(algorithm), size_fn, options,
                     /*module_sequence=*/nullptr);
  TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
                                         points_to_analysis));
  return heap.Finish();
}

// Runs a heap simulation for the given 'computation', assuming the given
// 'instruction_sequence'.
Status HeapSimulator::RunComputation(
    const HloComputation& computation,
    const std::vector<const HloInstruction*>& instruction_sequence,
    const TuplePointsToAnalysis& points_to_analysis) {
  // The goal here is to minimize memory usage, assuming the given sequential
  // ordering of instructions.  The strategy is to walk through the instruction
  // sequence, calling Alloc and Free on the underlying heap algorithm.  The
  // heap algorithm takes care of packing and reducing fragmentation.
  //
  // 'live_buffers' tracks the liveness of each buffer that we assign, by
  // associating it with a set of HloInstructions that need to be visited.  When
  // the set becomes empty, the buffer is no longer used, and can be freed.
  FlatMap<const LogicalBuffer*, FlatSet<const HloInstruction*>> live_buffers;

  const HloInstruction* root = computation.root_instruction();
  auto output_source_buffers =
      points_to_analysis.GetPointsToSet(root).CreateFlattenedSet();

  std::vector<const LogicalBuffer*> dead_buffers_to_free;
  std::vector<const LogicalBuffer*> operand_buffers_to_free;
  for (const HloInstruction* instruction : instruction_sequence) {
    const TuplePointsToAnalysis::BufferDefinitionVector&
        buffers_defined_by_instruction =
            points_to_analysis.GetBuffersDefinedByInstruction(instruction);

    // Initialize live_buffers for each buffer that we're going to assign.  The
    // set of instructions that need to be visited contains all users of all
    // aliases.  The alias itself is not necessary; if it has users, the users
    // are necessarily scheduled after the alias.  And if it has no users, it is
    // either a dead value or an output, both of which are handled below.
    //
    // We ignore control dependencies here. The reasoning is that the control
    // dependencies have already been accounted for in the ordering of the given
    // 'instruction_sequence', and should not otherwise artificially extend the
    // lifetime of buffers that aren't already connected by a data dependency.
    dead_buffers_to_free.clear();
    for (const LogicalBuffer* buffer : buffers_defined_by_instruction) {
      if (IgnoreBuffer(buffer)) {
        continue;
      }
      FlatSet<const HloInstruction*>* live_set = nullptr;
      for (const BufferAlias& alias :
           points_to_analysis.GetBufferAliases(*buffer)) {
        const std::vector<HloInstruction*>& users =
            alias.instruction()->users();
        if (!users.empty()) {
          if (live_set == nullptr) {
            live_set = &live_buffers[buffer];
          }
          live_set->insert(users.begin(), users.end());
        }
      }

      // Add a nullptr sentinel to ensure entry parameters and output source
      // buffers are not freed until the very end.
      const bool entry_parameter =
          &computation == computation.parent()->entry_computation() &&
          buffer->instruction()->opcode() == HloOpcode::kParameter;
      const bool output = output_source_buffers.count(buffer) > 0;
      if (entry_parameter || output) {
        live_buffers[buffer].insert(nullptr);
      }

      // If the buffer has no users and isn't an entry parameter or output, it
      // must be a dead value.
      if (live_buffers.count(buffer) == 0) {
        dead_buffers_to_free.push_back(buffer);
      }
    }

    // Update live_buffers to indicate we've visited this instruction; this is
    // the inverse of the initialization logic.  We erase this instruction from
    // all source buffers of all operands of this instruction.  Buffers that
    // have no instructions left to visit are moved from live_buffers to
    // operand_buffers_to_free.
    operand_buffers_to_free.clear();
    for (const LogicalBuffer* operand_buffer :
         UniqueOperandSourceBuffers(instruction, points_to_analysis)) {
      if (IgnoreBuffer(operand_buffer)) {
        continue;
      }
      auto it = live_buffers.find(operand_buffer);
      FlatSet<const HloInstruction*>* live_set = &it->second;
      live_set->erase(instruction);
      if (live_set->empty()) {
        live_buffers.erase(it);
        operand_buffers_to_free.push_back(operand_buffer);
      }
    }

    // Allocate buffers defined by this instruction.  This is the latest point
    // that we can allocate; right before the buffer is first used.  This must
    // happen before dead or operand buffers are freed; the instruction reads
    // the operand buffers to produce its output.
    //
    // INVARIANT: Either Alloc or ShareBuffer will be called for each buffer
    // that we should assign.
    for (const LogicalBuffer* buffer : buffers_defined_by_instruction) {
      if (IgnoreBuffer(buffer)) {
        continue;
      }

      // Check whether the buffer can share with one of its operands; we can
      // save memory by sharing the buffer, rather than allocating a new one.
      // We can only share with the operand buffer if it is about to be freed;
      // we must be the last user of the buffer.
      bool shared = false;
      if (options_.may_reuse_operand_buffers) {
        for (const LogicalBuffer* operand_buffer : operand_buffers_to_free) {
          if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) &&
              buffer->instruction()->opcode() != HloOpcode::kCopy &&
              CanShareOperandBufferWithUser(
                  operand_buffer->instruction(), operand_buffer->index(),
                  buffer->instruction(), buffer->index(), points_to_analysis)) {
            ShareBuffer(buffer, operand_buffer, instruction);
            shared = true;
            break;
          }
        }
      }

      if (!shared) {
        Alloc(buffer, instruction);
      }
    }

    // If the whole module is sequential, we can save memory by running the
    // heap-simulation for sub-computations inline. E.g. the buffers for the
    // condition and body of a kWhile instruction are only live for the duration
    // of the instruction itself.
    //
    // The order that the sub-computations are simulated does not affect
    // correctness; since the whole module is sequential, we know that the
    // sub-computations will never be run concurrently.
    if (module_sequence_ != nullptr) {
      if (instruction->opcode() == HloOpcode::kCall ||
          instruction->opcode() == HloOpcode::kConditional ||
          instruction->opcode() == HloOpcode::kWhile) {
        for (const HloComputation* called_computation :
             instruction->called_computations()) {
          const std::vector<const HloInstruction*>& called_sequence =
              FindOrDie(*module_sequence_, called_computation);
          TF_RETURN_IF_ERROR(RunComputation(
              *called_computation, called_sequence, points_to_analysis));
        }
      }

      // Other sub-computations (e.g. Map, Reduce, ...) are skipped; they are
      // assigned "thread-local" allocations, meaning their buffers are not
      // allocated up-front at the beginning of the computation.
    }

    // Free buffers that are no longer live.  This is the earliest point that we
    // can de-allocate; right after the last use of the buffer.
    for (const LogicalBuffer* buffer : dead_buffers_to_free) {
      Free(buffer, instruction);
    }
    for (const LogicalBuffer* buffer : operand_buffers_to_free) {
      Free(buffer, instruction);
    }
  }

  // Any remaining live buffers must be entry parameters or output source
  // buffers, which had a nullptr sentinel added.  Free them now.
  for (const auto& buffer_pending : live_buffers) {
    const LogicalBuffer* buffer = buffer_pending.first;
    const FlatSet<const HloInstruction*>& pending = buffer_pending.second;
    CHECK_EQ(pending.size(), 1) << *buffer;
    CHECK(*pending.begin() == nullptr) << *buffer;
    Free(buffer, root);
  }

  return Status::OK();
}

HeapSimulator::HeapSimulator(
    std::unique_ptr<HeapAlgorithm> algorithm,
    const LogicalBuffer::SizeFunction& size_fn, const Options& options,
    const SequentialHloOrdering::HloModuleSequence* module_sequence)
    : no_fragmentation_stats_(MakeUnique<NoFragmentationStatsHeap>()),
      algorithm_(std::move(algorithm)),
      size_fn_(size_fn),
      options_(options),
      module_sequence_(module_sequence) {
  debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr);
}

HeapSimulator::~HeapSimulator() {}

bool HeapSimulator::IgnoreBuffer(const LogicalBuffer* buffer) const {
  // Buffers for constants are ignored unless the alloc_constants option is
  // set. Also ignore buffers that we're not meant to assign.
  //
  // TODO(b/32248867): For consistency, constants should get allocations.
  if (!options_.alloc_constants &&
      buffer->instruction()->opcode() == HloOpcode::kConstant) {
    return true;
  }
  return options_.buffers_to_assign != nullptr &&
         options_.buffers_to_assign->count(buffer) == 0;
}

// Alloc always calls the underlying heap algorithm.
void HeapSimulator::Alloc(const LogicalBuffer* buffer,
                          const HloInstruction* instruction) {
  CHECK(allocated_buffers_.count(buffer) == 0)
      << "Alloc called on allocated buffer: " << *buffer;
  CHECK(freed_buffers_.count(buffer) == 0)
      << "Alloc called on freed buffer: " << *buffer;

  allocated_buffers_.insert(buffer);
  const int64 size = size_fn_(*buffer);
  algorithm_->Alloc(buffer, size);
  no_fragmentation_stats_->Alloc(buffer, size);

  FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction,
                 nullptr);
}

// Free calls the underlying algorithm for non-shared buffers, and for shared
// buffers whose group liveness has expired.  Shared group liveness is tracked
// by maintaining a refcount; the Free call on the last buffer in the group
// causes Free to be called on the underlying algorithm.
void HeapSimulator::Free(const LogicalBuffer* buffer,
                         const HloInstruction* instruction) {
  auto shared_it = shared_buffers_.find(buffer);
  if (shared_it != shared_buffers_.end()) {
    std::shared_ptr<SharedGroup> group = shared_it->second;
    --group->refcount;
    if (group->refcount > 0) {
      return;
    }
    CHECK_EQ(group->refcount, 0)
        << "Free caused negative refcount on shared buffer: " << *buffer;
    buffer = group->canonical;
  }

  CHECK(allocated_buffers_.count(buffer) > 0)
      << "Free called on non-allocated buffer: " << *buffer;
  CHECK(freed_buffers_.count(buffer) == 0)
      << "Free called on freed buffer: " << *buffer;

  freed_buffers_.insert(buffer);
  const int64 size = size_fn_(*buffer);
  algorithm_->Free(buffer, size);
  no_fragmentation_stats_->Free(buffer, size);

  FillDebugTrace(HeapSimulatorTrace::Event::FREE, buffer, instruction, nullptr);
}

// ShareBuffer associates buffers with their SharedGroup in shared_buffers_.
// The 'buffer' must be a non-allocated, non-freed buffer, just like in calls to
// Alloc.  The 'shared' buffer must be a previously allocated or shared buffer.
// Both 'buffer' and 'shared' will be associated with the same SharedGroup.
void HeapSimulator::ShareBuffer(const LogicalBuffer* buffer,
                                const LogicalBuffer* shared,
                                const HloInstruction* instruction) {
  CHECK_LE(size_fn_(*buffer), size_fn_(*shared))
      << "ShareBuffer oversized buffer: " << *buffer << " shared: " << *shared;
  CHECK(allocated_buffers_.count(buffer) == 0)
      << "ShareBuffer called on allocated buffer: " << *buffer;
  CHECK(freed_buffers_.count(buffer) == 0)
      << "ShareBuffer called on freed buffer: " << *buffer;
  CHECK(freed_buffers_.count(shared) == 0)
      << "ShareBuffer called on freed shared buffer: " << *shared;

  const LogicalBuffer* canonical = nullptr;
  auto shared_it = shared_buffers_.find(shared);
  if (shared_it != shared_buffers_.end()) {
    // The 'shared' buffer already has a group; it might be the canonical, but
    // also might not be.  Just add 'buffer' to the existing group.
    std::shared_ptr<SharedGroup> group = shared_it->second;
    canonical = group->canonical;
    ++group->refcount;
    shared_buffers_.emplace(buffer, group);
  } else {
    // The 'shared' buffer doesn't have a group; it must be the canonical.  Add
    // both 'buffer' and 'shared' to a new group.
    CHECK(allocated_buffers_.count(shared) > 0)
        << "ShareBuffer called on non-allocated shared buffer: " << *shared;
    auto group = std::make_shared<SharedGroup>();
    canonical = shared;
    group->canonical = canonical;
    group->refcount = 2;
    shared_buffers_.emplace(buffer, group);
    shared_buffers_.emplace(shared, group);
  }

  FillDebugTrace(HeapSimulatorTrace::Event::SHARE_WITH, buffer, instruction,
                 canonical);
}

HeapSimulator::Result HeapSimulator::Finish() {
  Result result = algorithm_->Finish();

  // Post-process the result to add chunks for shared buffers.  An empty chunk
  // map means that either no buffers were allocated, or the heap was only
  // collecting statistics, e.g. NoFragmentationStatsHeap.
  if (!result.chunk_map.empty()) {
    for (const auto& share_pair : shared_buffers_) {
      const LogicalBuffer* buffer = share_pair.first;
      std::shared_ptr<SharedGroup> group = share_pair.second;
      if (buffer != group->canonical) {
        // The canonical must already exist in the chunk_map, since we called
        // Alloc(canonical) on the underlying algorithm.  Add non-canonical
        // chunks with the same offset as the canonical.
        Chunk chunk = FindOrDie(result.chunk_map, group->canonical);
        chunk.size = size_fn_(*buffer);
        result.chunk_map.emplace(buffer, chunk);
      }
    }
    // If we were told to assign specific buffers, make sure we've assigned
    // exactly that many buffers.
    if (options_.buffers_to_assign != nullptr) {
      CHECK_EQ(options_.buffers_to_assign->size(), result.chunk_map.size());
    }
  }

  // Fragmentation is the difference between the actual and ideal sizes.
  const Result no_frag_result = no_fragmentation_stats_->Finish();
  result.fragmentation_size = result.heap_size - no_frag_result.heap_size;

  // Copy the debug trace we collected to the final result.
  result.debug_trace.Swap(&debug_trace_);

  return result;
}

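// Appends an event to the debug trace, recording the event kind, the buffer,
// and the instruction at which the event occurred.  For SHARE_WITH events the
// canonical buffer of the shared group is also recorded.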
void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind,
                                   const LogicalBuffer* buffer,
                                   const HloInstruction* instruction,
                                   const LogicalBuffer* share_with_canonical) {
  HeapSimulatorTrace::Event* event = debug_trace_.add_events();
  event->set_kind(kind);
  event->set_buffer_id(buffer->id());
  event->set_computation_name(instruction->parent()->name());
  event->set_instruction_name(instruction->name());
  if (kind == HeapSimulatorTrace::Event::SHARE_WITH) {
    CHECK(share_with_canonical != nullptr);
    event->set_share_with_canonical_id(share_with_canonical->id());
  } else {
    CHECK(share_with_canonical == nullptr);
  }
}

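// NoFragmentationStatsHeap only tracks the current and peak heap sizes; it
// never records per-buffer chunk assignments.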
void NoFragmentationStatsHeap::Alloc(const LogicalBuffer* buffer, int64 size) {
  current_heap_size_ += size;
  if (current_heap_size_ > max_heap_size_) {
    max_heap_size_ = current_heap_size_;
  }
}

void NoFragmentationStatsHeap::Free(const LogicalBuffer* buffer, int64 size) {
  current_heap_size_ -= size;
}

HeapSimulator::Result NoFragmentationStatsHeap::Finish() {
  // The result.chunk_map is empty, since we only collect stats, and don't
  // actually compute chunk assignments.
  Result result;
  result.heap_size = max_heap_size_;
  return result;
}

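// DecreasingSizeRunsHeap batches consecutive Alloc (or Free) calls into runs,
// and forwards each run to the underlying algorithm sorted by decreasing size.
// E.g. the sequence Alloc(a), Alloc(b), Free(c), Free(d) forms an Alloc run
// {a, b} followed by a Free run {c, d}.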
void DecreasingSizeRunsHeap::Alloc(const LogicalBuffer* buffer, int64 size) {
  SetMode(kAlloc);
  run_.emplace_back(Op{buffer, size});
}

void DecreasingSizeRunsHeap::Free(const LogicalBuffer* buffer, int64 size) {
  CHECK(mode_ != kInit) << "Free called on empty heap: " << *buffer;
  SetMode(kFree);
  run_.emplace_back(Op{buffer, size});
}

HeapSimulator::Result DecreasingSizeRunsHeap::Finish() {
  CallAndDrainRun();
  return algorithm_->Finish();
}

void DecreasingSizeRunsHeap::SetMode(Mode mode) {
  if (mode_ != mode) {
    CallAndDrainRun();
    mode_ = mode;
  }
}

void DecreasingSizeRunsHeap::CallAndDrainRun() {
  if (mode_ == kInit) {
    CHECK(run_.empty());
    return;
  }

  // Call ops in the run sorted by decreasing size, breaking ties by buffer id.
  std::sort(run_.begin(), run_.end(), [](const Op& a, const Op& b) {
    if (a.size != b.size) {
      return a.size > b.size;
    }
    return a.buffer->id() < b.buffer->id();
  });
  for (const Op& op : run_) {
    if (mode_ == kAlloc) {
      algorithm_->Alloc(op.buffer, op.size);
    } else {
      algorithm_->Free(op.buffer, op.size);
    }
  }
  run_.clear();
}

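// Alloc tries three strategies in order: place the buffer in the best-fitting
// free chunk, grow the heap by re-using a free chunk that touches the current
// end of the heap, and otherwise defer placement to Free, marking the chunk
// with kLazyAllocOffset.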
void LazyBestFitHeap::Alloc(const LogicalBuffer* buffer, int64 size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    result_.chunk_map.emplace(buffer, Chunk{0, 0});
  }

  // First try to allocate from the best-fitting free chunk.
  auto best_fit_it = free_.lower_bound(Chunk{0, size});
  while (best_fit_it != free_.end()) {
    // Account for alignment.
    const Chunk best = *best_fit_it;
    const int64 new_offset = RoundUpToNearest(best.offset, alignment_);
    const int64 new_end = new_offset + size;
    if (new_end > best.chunk_end()) {
      // We don't fit after accounting for alignment.
      ++best_fit_it;
      continue;
    }
    // The buffer is allocated a chunk out of the best-fitting free chunk.
    free_.erase(best_fit_it);
    result_.chunk_map.emplace(buffer, Chunk{new_offset, size});
    // Add remaining portions of the best-fitting free chunk back into free_.
    AddFreeChunk(best.offset, new_offset - best.offset);
    AddFreeChunk(new_end, best.chunk_end() - new_end);
    return;
  }

  // The buffer doesn't completely fit into any existing free chunk.  If the
  // last free chunk is adjacent to the end of the heap, allocate the buffer
  // re-using that space, increasing the heap size.
  //
  // Allocating the buffer now causes the heap to grow by less than the buffer
  // size, whereas if we allocated lazily in Free, the heap would grow by
  // exactly the buffer size.  However it's still a greedy heuristic; we might
  // have ended up with a tighter packing by being lazy here.
  //
  // In theory we could also check whether we could re-use space from the
  // first free chunk and grow the heap at the front, choosing whether to grow
  // from the front or back based on the amount of re-use.  But that's more
  // complicated, and these are all heuristics anyway, so it isn't implemented.
  for (auto it = free_.begin(); it != free_.end(); ++it) {
    if (it->chunk_end() == result_.heap_size) {
      // Account for alignment in the last free chunk.
      const Chunk last = *it;
      const int64 new_offset = RoundUpToNearest(last.offset, alignment_);
      if (new_offset >= last.chunk_end()) {
        // There's no point in using the last free chunk if alignment causes us
        // to skip over it anyway.
        break;
      }
      // The buffer is allocated a chunk that includes the last free chunk.
      free_.erase(it);
      result_.chunk_map.emplace(buffer, Chunk{new_offset, size});
      // Add remaining portion of the last free chunk back into free_.
      AddFreeChunk(last.offset, new_offset - last.offset);
      // Grow the heap.
      const int64 new_end = new_offset + size;
      CHECK_GT(new_end, result_.heap_size);
      CHECK_LT(new_end, result_.heap_size + size);
      result_.heap_size = new_end;
      return;
    }
  }

  // Otherwise lazily allocate the buffer in Free.
  result_.chunk_map.emplace(buffer, Chunk{kLazyAllocOffset, size});
}

void LazyBestFitHeap::Free(const LogicalBuffer* buffer, int64 size) {
  auto alloc_it = result_.chunk_map.find(buffer);
  CHECK(alloc_it != result_.chunk_map.end())
      << "Free called on non-allocated buffer: " << *buffer;
  Chunk* alloc = &alloc_it->second;
  CHECK_EQ(alloc->size, size) << "Free with mismatched sizes: " << *buffer;
  if (alloc->offset != kLazyAllocOffset) {
    // The buffer was already allocated in Alloc, do a normal free.
    AddFreeChunk(alloc->offset, alloc->size);
  } else {
    // This buffer is lazily allocated, so we *can not* allocate out of existing
    // free chunks, since that might cause interference between buffers.  The
    // buffer is allocated by growing the heap, accounting for alignment.
    alloc->offset = RoundUpToNearest(result_.heap_size, alignment_);
    const int64 new_end = alloc->chunk_end();
    AddFreeChunk(result_.heap_size, new_end - result_.heap_size);
    CHECK_GT(new_end, result_.heap_size);
    CHECK_GE(new_end, result_.heap_size + alloc->size);
    result_.heap_size = new_end;
  }
}

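// AddFreeChunk coalesces the new chunk with any adjacent free chunks before
// inserting it.  E.g. inserting [8, 16) while [0, 8) and [16, 32) are free
// yields a single free chunk [0, 32).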
void LazyBestFitHeap::AddFreeChunk(int64 offset, int64 size) {
  if (size <= 0) {
    return;
  }

  // Coalesce the chunk with adjacent free chunks on either side.  We must
  // remove the free chunks from free_, since it's ordered by size.
  Chunk chunk{offset, size};
  for (auto it = free_.begin(); it != free_.end();) {
    if (it->chunk_end() == chunk.offset || it->offset == chunk.chunk_end()) {
      chunk.offset = std::min(chunk.offset, it->offset);
      chunk.size += it->size;
      it = free_.erase(it);
    } else {
      ++it;
    }
  }

  // This is the only place we add free chunks to free_.  It maintains the
  // invariant that all free chunks are disjoint and non-adjacent.
  free_.emplace(chunk);
}

HeapSimulator::Result LazyBestFitHeap::Finish() {
  if (!free_.empty()) {
    // When Finish is called, all calls to Alloc must have had corresponding
    // calls to Free, which will result in a single free chunk [0, heap_size).
    CHECK_EQ(free_.size(), 1);
    CHECK_EQ(free_.begin()->offset, 0);
    CHECK_EQ(free_.begin()->size, result_.heap_size);
  }
  return result_;
}

}  // namespace xla