      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Defines the data returned by the XLA buffer assignment packages.
     17 
     18 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
     19 
     20 #include <algorithm>
     21 #include <deque>
     22 #include <ostream>
     23 #include <utility>
     24 
     25 #include "tensorflow/compiler/xla/map_util.h"
     26 #include "tensorflow/compiler/xla/ptr_util.h"
     27 #include "tensorflow/compiler/xla/service/heap_simulator.h"
     28 #include "tensorflow/compiler/xla/service/hlo.pb.h"
     29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     30 #include "tensorflow/compiler/xla/service/hlo_scheduling.h"
     31 #include "tensorflow/compiler/xla/shape_util.h"
     32 #include "tensorflow/compiler/xla/status_macros.h"
     33 #include "tensorflow/compiler/xla/types.h"
     34 #include "tensorflow/compiler/xla/util.h"
     35 #include "tensorflow/core/lib/core/errors.h"
     36 #include "tensorflow/core/lib/hash/hash.h"
     37 #include "tensorflow/core/lib/strings/numbers.h"
     38 #include "tensorflow/core/lib/strings/str_util.h"
     39 #include "tensorflow/core/lib/strings/strcat.h"
     40 #include "tensorflow/core/lib/strings/stringprintf.h"
     41 
     42 namespace xla {
     43 
     44 using ::tensorflow::gtl::FlatMap;
     45 using ::tensorflow::gtl::FlatSet;
     46 using ::tensorflow::strings::Appendf;
     47 using ::tensorflow::strings::HumanReadableNumBytes;
     48 using ::tensorflow::strings::Printf;
     49 using ::tensorflow::strings::StrAppend;
     50 
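         // Hashes a Slice by combining the hashes of its allocation index, offset,
         // and size, so Slices can be used as keys in hash-based containers (e.g.
         // the FlatSet used by HaveDisjointSlices below).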
     51 size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const {
     52   uint64 h = std::hash<int64>()(s.index());
     53   h = tensorflow::Hash64Combine(h, std::hash<int64>()(s.offset()));
     54   h = tensorflow::Hash64Combine(h, std::hash<int64>()(s.size()));
     55   return h;
     56 }
     57 
     58 string BufferAllocation::Slice::ToString() const {
     59   return tensorflow::strings::StrCat("{index:", index(), ", offset:", offset_,
     60                                      ", size:", size_, "}");
     61 }
     62 
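         // Returns the Slice occupied by `buffer` within this allocation.
         // CHECK-fails (via FindOrDie) if `buffer` has not been assigned here.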
     63 BufferAllocation::Slice BufferAllocation::GetSlice(
     64     const LogicalBuffer& buffer) const {
     65   const OffsetSize os = FindOrDie(assigned_buffers_, &buffer);
     66   return Slice(this, os.offset, os.size);
     67 }
     68 
     69 void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset,
     70                                      int64 size) {
     71   VLOG(4) << "Trying to add " << buffer << " to " << this;
     72   CHECK(assigned_buffers_.count(&buffer) == 0)
     73       << "LogicalBuffer " << buffer << " already assigned to allocation "
     74       << index_;
     75   CHECK_LE(offset, size_) << "LogicalBuffer " << buffer
     76                           << " offset out of range";
     77   CHECK_LE(offset + size, size_)
     78       << "LogicalBuffer " << buffer
     79       << " size out of range at offset: " << offset << " with size: " << size;
     80   CHECK_EQ(buffer.color(), color())
     81       << "Buffer color " << buffer.color() << " for buffer " << buffer
     82       << " does not match allocation color " << color() << ".";
     83   OffsetSize offset_size;
     84   offset_size.offset = offset;
     85   offset_size.size = size;
     86   assigned_buffers_.emplace(&buffer, offset_size);
     87 }
     88 
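         // Serializes this allocation. The assigned-buffer list is sorted by
         // logical buffer id so the emitted proto is deterministic.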
     89 BufferAllocationProto BufferAllocation::ToProto() const {
     90   BufferAllocationProto proto;
     91   proto.set_index(index_);
     92   proto.set_size(size_);
     93   proto.set_is_thread_local(is_thread_local_);
     94   proto.set_is_reusable(is_reusable_);
     95   proto.set_color(color_.value());
     96   if (is_entry_computation_parameter_) {
     97     proto.set_is_entry_computation_parameter(true);
     98     for (int64 idx : param_shape_index()) {
     99       proto.add_parameter_shape_index(idx);
    100     }
    101     proto.set_parameter_number(parameter_number_);
    102   }
    103   proto.set_maybe_live_out(maybe_live_out_);
    104   for (const auto& buffer_offset_size : assigned_buffers_) {
    105     BufferAllocationProto::Assigned* proto_assigned = proto.add_assigned();
    106     proto_assigned->set_logical_buffer_id(buffer_offset_size.first->id());
    107     proto_assigned->set_offset(buffer_offset_size.second.offset);
    108     proto_assigned->set_size(buffer_offset_size.second.size);
    109   }
    110   std::sort(proto.mutable_assigned()->begin(), proto.mutable_assigned()->end(),
    111             [](const BufferAllocationProto::Assigned& assign1,
    112                const BufferAllocationProto::Assigned& assign2) {
    113               return assign1.logical_buffer_id() < assign2.logical_buffer_id();
    114             });
    115   return proto;
    116 }
    117 
    118 string BufferAllocation::ToString() const {
    119   string output;
    120   Appendf(&output, "allocation %lld: %p, size %lld", index_, this, size());
    121   if (color().value() != 0) {
    122     StrAppend(&output, ", color ", color().value());
    123   }
    124   if (is_entry_computation_parameter()) {
    125     StrAppend(&output, ", parameter ", parameter_number(), " at ShapeIndex ",
    126               param_shape_index().ToString());
    127   }
    128   if (is_thread_local()) {
    129     StrAppend(&output, ", thread-local");
    130   }
    131   if (maybe_live_out()) {
    132     StrAppend(&output, ", maybe-live-out");
    133   }
    134   if (IsPreallocatedTempBuffer()) {
    135     StrAppend(&output, ", preallocated-temp");
    136   }
    137   StrAppend(&output, ":\n");
    138   // Dump the assigned buffers ordered by id.
    139   std::vector<const LogicalBuffer*> sorted_buffers;
    140   for (const auto& buffer_offset_size : assigned_buffers_) {
    141     sorted_buffers.push_back(buffer_offset_size.first);
    142   }
    143   std::sort(sorted_buffers.begin(), sorted_buffers.end(),
    144             [](const LogicalBuffer* a, const LogicalBuffer* b) {
    145               return a->id() < b->id();
    146             });
    147   for (const LogicalBuffer* buffer : sorted_buffers) {
    148     const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer);
    149     StrAppend(&output,
    150               tensorflow::strings::Printf(
    151                   "  %s [%lld,%lld]: %s\n", buffer->ToString().c_str(),
    152                   offset_size.offset, offset_size.size,
    153                   ShapeUtil::HumanStringWithLayout(buffer->shape()).c_str()));
    154   }
    155   return output;
    156 }
    157 
    158 std::ostream& operator<<(std::ostream& out, const BufferAllocation& buffer) {
    159   out << buffer.ToString();
    160   return out;
    161 }
    162 
    163 std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s) {
    164   out << s.ToString();
    165   return out;
    166 }
    167 
    168 const PointsToSet& BufferAssignment::GetPointsToSet(
    169     const HloInstruction* instruction) const {
    170   return points_to_analysis().GetPointsToSet(instruction);
    171 }
    172 
    173 bool BufferAssignment::HasAllocation(const LogicalBuffer& buffer) const {
    174   TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer));
    175   return allocation_index_for_buffer_.count(&buffer) > 0;
    176 }
    177 
    178 const BufferAllocation& BufferAssignment::GetAssignedAllocation(
    179     const LogicalBuffer& buffer) const {
    180   CHECK(HasAllocation(buffer));
    181   return GetAllocation(allocation_index_for_buffer_.at(&buffer));
    182 }
    183 
    184 BufferAllocation* BufferAssignment::GetMutableAssignedAllocation(
    185     const LogicalBuffer& buffer) {
    186   return const_cast<BufferAllocation*>(&GetAssignedAllocation(buffer));
    187 }
    188 
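         // Collects every slice assigned to any source buffer of (instruction,
         // index). Buffers without an allocation are skipped, so the result may
         // be empty.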
    189 std::set<BufferAllocation::Slice> BufferAssignment::GetAllSlices(
    190     const HloInstruction* instruction, const ShapeIndex& index) const {
    191   std::set<BufferAllocation::Slice> result;
    192   for (const LogicalBuffer* buffer : GetSourceBuffers(instruction, index)) {
    193     if (HasAllocation(*buffer)) {
    194       result.insert(GetAssignedAllocation(*buffer).GetSlice(*buffer));
    195     }
    196   }
    197   return result;
    198 }
    199 
    200 const BufferAllocation& BufferAssignment::GetAllocation(
    201     BufferAllocation::Index index) const {
    202   CHECK_GE(index, 0);
    203   CHECK_LT(index, allocations_.size());
    204   return allocations_[index];
    205 }
    206 
    207 BufferAllocation* BufferAssignment::GetMutableAllocation(
    208     BufferAllocation::Index index) {
    209   return const_cast<BufferAllocation*>(&GetAllocation(index));
    210 }
    211 
    212 bool BufferAssignment::HasAllocationAt(const HloInstruction* instruction,
    213                                        const ShapeIndex& index) const {
    214   for (const LogicalBuffer* buffer :
    215        GetPointsToSet(instruction).element(index)) {
    216     if (allocation_index_for_buffer_.count(buffer) > 0) {
    217       return true;
    218     }
    219   }
    220   return false;
    221 }
    222 
    223 bool BufferAssignment::HasTopLevelAllocation(
    224     const HloInstruction* instruction) const {
    225   return HasAllocationAt(instruction, /*index=*/{});
    226 }
    227 
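         // Returns the single slice assigned to (instruction, index), or a
         // FailedPrecondition error if no slice is assigned or more than one
         // distinct slice is found.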
    228 StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueSlice(
    229     const HloInstruction* instruction, const ShapeIndex& index) const {
    230   VLOG(3) << "Trying to find unique slice for " << instruction->name() << " ["
    231           << index << "]";
    232   BufferAllocation::Slice result;
    233   for (const LogicalBuffer* buffer :
    234        GetPointsToSet(instruction).element(index)) {
    235     VLOG(3) << "Examining buffer " << *buffer;
    236     if (HasAllocation(*buffer)) {
    237       VLOG(3) << "Has allocation";
    238       const BufferAllocation::Slice slice =
    239           GetAssignedAllocation(*buffer).GetSlice(*buffer);
    240       if (result.allocation() == nullptr) {
    241         result = slice;
    242       } else if (result != slice) {
    243         return FailedPrecondition(
    244             "BufferAllocation::Slice for instruction %s at index %s cannot "
    245             "be determined at compile-time.",
    246             instruction->name().c_str(), index.ToString().c_str());
    247       }
    248     } else {
    249       VLOG(3) << "No allocation";
    250     }
    251   }
    252   if (result.allocation() == nullptr) {
    253     return FailedPrecondition(
    254         "BufferAllocation::Slice not assigned for instruction %s at index %s",
    255         instruction->name().c_str(), index.ToString().c_str());
    256   }
    257   return result;
    258 }
    259 
    260 StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueTopLevelSlice(
    261     const HloInstruction* instruction) const {
    262   return GetUniqueSlice(instruction, /*index=*/{});
    263 }
    264 
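         // Note: both lookups use ConsumeValueOrDie, so this CHECK-fails if either
         // (hlo, shape_index) pair does not have a unique slice.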
    265 bool BufferAssignment::SharesSliceAtIndex(
    266     const HloInstruction* hlo_a, const ShapeIndex& shape_index_a,
    267     const HloInstruction* hlo_b, const ShapeIndex& shape_index_b) const {
    268   return GetUniqueSlice(hlo_a, shape_index_a).ConsumeValueOrDie() ==
    269          GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie();
    270 }
    271 
    272 bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
    273                                           const HloInstruction* hlo_b) const {
    274   using SliceSet =
    275       FlatSet<BufferAllocation::Slice, BufferAllocation::Slice::Hasher>;
     276   // Gets the slices of all of instr's subshapes.  If any subshape doesn't have an
    277   // assigned slice, returns the empty set.
    278   auto collect_slices = [&](const HloInstruction* instr) -> SliceSet {
    279     SliceSet slices;
    280     Status status = ShapeUtil::ForEachSubshapeWithStatus(
    281         instr->shape(),
    282         [&](const Shape& /*subshape*/, const ShapeIndex& index) {
    283           auto shape_slices = GetAllSlices(instr, index);
    284           if (shape_slices.empty()) {
    285             return InvalidArgument("No slices assigned to part of instr.");
    286           }
    287           slices.insert(shape_slices.begin(), shape_slices.end());
    288           return Status::OK();
    289         });
    290     if (!status.ok()) {
    291       return {};
    292     }
    293     return slices;
    294   };
    295 
    296   SliceSet slices_a = collect_slices(hlo_a);
    297   SliceSet slices_b = collect_slices(hlo_b);
    298   // hlo_a and hlo_b have disjoint slices if collect_slices succeeded (i.e.
    299   // didn't return the empty set) for both HLOs, and the two resulting sets of
    300   // slices are disjoint.
    301   return !slices_a.empty() && !slices_b.empty() &&
    302          std::none_of(slices_a.begin(), slices_a.end(),
    303                       [&](const BufferAllocation::Slice& slice) {
    304                         return slices_b.count(slice) > 0;
    305                       });
    306 }
    307 
    308 StatusOr<BufferAllocation::Slice>
    309 BufferAssignment::GetUniqueTopLevelOutputSlice() const {
    310   return GetUniqueTopLevelSlice(
    311       module_->entry_computation()->root_instruction());
    312 }
    313 
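         // Creates a new allocation of `size` bytes with no buffers assigned yet;
         // callers populate it via AddAssignment (or NewAllocation below).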
    314 BufferAllocation* BufferAssignment::NewEmptyAllocation(
    315     int64 size, bool is_thread_local, bool is_reusable,
    316     LogicalBuffer::Color color) {
    317   BufferAllocation::Index index = allocations_.size();
    318   allocations_.emplace_back(index, size, is_thread_local, is_reusable, color);
    319   BufferAllocation* allocation = &allocations_.back();
    320   return allocation;
    321 }
    322 
    323 BufferAllocation* BufferAssignment::NewAllocation(const LogicalBuffer& buffer,
    324                                                   int64 size,
    325                                                   bool is_thread_local,
    326                                                   bool is_reusable) {
    327   BufferAllocation* allocation =
    328       NewEmptyAllocation(size, is_thread_local, is_reusable, buffer.color());
    329   AddAssignment(allocation, buffer, /*offset=*/0, size);
    330   return allocation;
    331 }
    332 
     333 // Assigns the given buffer to the given allocation at [offset, offset + size).
    334 void BufferAssignment::AddAssignment(BufferAllocation* allocation,
    335                                      const LogicalBuffer& buffer, int64 offset,
    336                                      int64 size) {
    337   CHECK_EQ(0, allocation_index_for_buffer_.count(&buffer))
    338       << "LogicalBuffer " << buffer << " already has an allocation.";
    339   CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty())
    340       << "Non-reusable allocation already assigned a buffer";
    341 
    342   TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer));
    343 
    344   allocation->AddAssignment(buffer, offset, size);
    345   allocation_index_for_buffer_[&buffer] = allocation->index();
    346 }
    347 
    348 // Combines allocations of temporary buffers of the same color into one big
    349 // BufferAllocation.
    350 void BufferAssignment::CombineTempAllocations() {
    351   FlatMap<LogicalBuffer::Color, BufferAllocation, LogicalBuffer::Color::Hasher>
    352       combined_allocation_map;
    353 
    354   // Move all temp allocations into a single run at the end of the allocations
    355   // vector.
    356   const auto first_temp_it =
    357       std::partition(allocations_.begin(), allocations_.end(),
    358                      [](const BufferAllocation& allocation) {
    359                        return !allocation.IsPreallocatedTempBuffer();
    360                      });
    361 
    362   // Walk over the run of temp allocations, collecting the allocations belonging
    363   // to the same color.
    364   if (first_temp_it != allocations_.end()) {
    365     for (auto it = first_temp_it; it != allocations_.end(); ++it) {
    366       const BufferAllocation& temp_allocation = *it;
    367       LogicalBuffer::Color color = temp_allocation.color();
    368       auto combined_it = combined_allocation_map.find(color);
    369       if (combined_it == combined_allocation_map.end()) {
    370         // We have found the first temp allocation of this color. Collect
    371         // the other temp allocations of the same color into it.
    372         combined_allocation_map.emplace(color, temp_allocation);
    373         continue;
    374       }
    375 
    376       auto* combined_allocation = &combined_it->second;
    377       // Each temp allocation is placed end-to-end, accounting for alignment.
    378       // The offset of each buffer in the combined allocation is computed from
    379       // the base offset of the allocation.
    380       int64 alignment = color_alignment_(color);
    381       const int64 base =
    382           RoundUpToNearest(combined_allocation->size(), alignment);
    383       combined_allocation->set_size(base + temp_allocation.size());
    384       for (const auto& buffer_offset_size : temp_allocation.assigned_buffers_) {
    385         const LogicalBuffer* buffer = buffer_offset_size.first;
    386         const int64 offset = buffer_offset_size.second.offset;
    387         const int64 size = buffer_offset_size.second.size;
    388         combined_allocation->AddAssignment(*buffer, base + offset, size);
    389       }
    390     }
    391     // Replace all existing temporary allocations with the new combined
    392     // allocations.
    393     allocations_.erase(first_temp_it, allocations_.end());
    394     for (auto& combined : combined_allocation_map) {
    395       allocations_.push_back(combined.second);
    396       temp_allocation_total_size_ += combined.second.size();
    397     }
    398   }
    399 
    400   // Update allocation indices to their new positions.
    401   allocation_index_for_buffer_.clear_no_resize();
    402   for (size_t index = 0; index < allocations_.size(); ++index) {
    403     BufferAllocation* allocation = &allocations_[index];
    404     allocation->set_index(index);
    405     for (const auto& buffer_offset_size : allocation->assigned_buffers_) {
    406       const LogicalBuffer* buffer = buffer_offset_size.first;
    407       allocation_index_for_buffer_[buffer] = index;
    408     }
    409   }
    410 }
    411 
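         // Accumulates per-category allocation counts and byte totals. Total
         // fragmentation is only computed when every computation has a sequential
         // order, since MinimumMemoryForSequence needs a whole-module sequence.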
    412 Status BufferAssignment::ComputeSummaryStats() {
    413   for (auto& allocation : Allocations()) {
    414     if (allocation.is_entry_computation_parameter()) {
    415       stats_.parameter_allocation_count++;
    416       stats_.parameter_allocation_bytes += allocation.size();
    417     }
    418     if (allocation.maybe_live_out()) {
    419       stats_.maybe_live_out_allocation_count++;
    420       stats_.maybe_live_out_allocation_bytes += allocation.size();
    421     }
    422     if (allocation.IsPreallocatedTempBuffer()) {
    423       stats_.preallocated_temp_allocation_count++;
    424       stats_.preallocated_temp_allocation_bytes += allocation.size();
    425     }
    426     stats_.total_allocation_count++;
    427     stats_.total_allocation_bytes += allocation.size();
    428   }
    429 
    430   // Only compute total fragmentation if all computations are sequential.
    431   SequentialHloOrdering::HloModuleSequence module_sequence;
    432   for (const auto& computation : module_->computations()) {
    433     const std::vector<const HloInstruction*>* sequence =
    434         liveness_->hlo_ordering().SequentialOrder(*computation);
    435     if (sequence != nullptr) {
    436       module_sequence.emplace(computation, *sequence);
    437     }
    438   }
    439   if (module_sequence.size() == module_->computation_count()) {
    440     TF_ASSIGN_OR_RETURN(
    441         const int64 min_size,
    442         MinimumMemoryForSequence(module_sequence, buffer_size_));
    443     stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size;
    444   }
    445 
    446   return Status::OK();
    447 }
    448 
    449 string BufferAssignment::Stats::ToString() const {
    450   string s;
    451   Appendf(&s, "BufferAssignment stats:\n");
    452   Appendf(&s, "             parameter allocation: %10s\n",
    453           HumanReadableNumBytes(parameter_allocation_bytes).c_str());
    454   Appendf(&s, "        maybe_live_out allocation: %10s\n",
    455           HumanReadableNumBytes(maybe_live_out_allocation_bytes).c_str());
    456   Appendf(&s, "     preallocated temp allocation: %10s\n",
    457           HumanReadableNumBytes(preallocated_temp_allocation_bytes).c_str());
    458   if (preallocated_temp_fragmentation_bytes >= 0) {
    459     const double percent = 100. * preallocated_temp_fragmentation_bytes /
    460                            preallocated_temp_allocation_bytes;
    461     Appendf(
    462         &s, "  preallocated temp fragmentation: %10s (%.2f%%)\n",
    463         HumanReadableNumBytes(preallocated_temp_fragmentation_bytes).c_str(),
    464         percent);
    465   }
    466   Appendf(&s, "                 total allocation: %10s\n",
    467           HumanReadableNumBytes(total_allocation_bytes).c_str());
    468   if (total_fragmentation_bytes >= 0) {
    469     const double percent =
    470         100. * total_fragmentation_bytes / total_allocation_bytes;
    471     Appendf(&s, "              total fragmentation: %10s (%.2f%%)\n",
    472             HumanReadableNumBytes(total_fragmentation_bytes).c_str(), percent);
    473   }
    474   return s;
    475 }
    476 
    477 string BufferAssignment::ToString() const {
    478   string output;
    479   tensorflow::strings::StrAppend(&output, "BufferAssignment:\n");
    480   for (auto& allocation : allocations_) {
    481     tensorflow::strings::StrAppend(&output, allocation.ToString());
    482   }
    483   return output;
    484 }
    485 
    486 BufferAssignmentProto BufferAssignment::ToProto() const {
    487   BufferAssignmentProto proto;
     488   // NOTE: TuplePointsToAnalysis state is serialized here in BufferAssignment,
    489   // because we need to do the HasAllocation check for each buffer. Otherwise
    490   // the buffer_size_ call might fail for some backends.
    491   const TuplePointsToAnalysis& points_to_analysis =
    492       liveness_->points_to_analysis();
    493   for (LogicalBuffer::Id id = 0; id < points_to_analysis.num_logical_buffers();
    494        id++) {
    495     auto& buffer = points_to_analysis.logical_buffer(id);
    496     if (HasAllocation(buffer)) {
    497       LogicalBufferProto proto_buffer = buffer.ToProto(buffer_size_);
    498       proto.add_logical_buffers()->Swap(&proto_buffer);
    499 
    500       // Fill buffer aliases.
    501       for (const BufferAlias& alias :
    502            points_to_analysis.GetBufferAliases(buffer)) {
    503         if (alias.instruction() == buffer.instruction() &&
    504             alias.index() == buffer.index()) {
    505           continue;  // skip self-aliases
    506         }
    507         BufferAssignmentProto::BufferAlias* proto_alias =
    508             proto.add_buffer_aliases();
    509         LogicalBufferProto::Location proto_alias_location =
    510             LogicalBuffer::ToLocationProto(*alias.instruction(), alias.index());
    511         proto_alias->set_source_buffer_id(buffer.id());
    512         proto_alias->mutable_location()->Swap(&proto_alias_location);
    513       }
    514     }
    515   }
    516   for (const BufferAllocation& allocation : Allocations()) {
    517     BufferAllocationProto proto_allocation = allocation.ToProto();
    518     proto.add_buffer_allocations()->Swap(&proto_allocation);
    519   }
    520   for (const HeapSimulatorTrace& trace : heap_simulator_traces_) {
    521     *proto.add_heap_simulator_traces() = trace;
    522   }
    523   return proto;
    524 }
    525 
    526 namespace {
    527 
    528 // Walk the call graph of the HLO module and place each computation into either
    529 // thread_local_computations or global_computations depending upon whether the
    530 // computation requires thread-local allocations or global allocations. The
    531 // elements in thread_local_computations and global_computations are in post
    532 // order (if computation A has an instruction which calls computation B, then A
    533 // will appear after B in the vector).
    534 Status GatherComputationsByAllocationType(
    535     const HloModule* module,
    536     std::vector<const HloComputation*>* thread_local_computations,
    537     std::vector<const HloComputation*>* global_computations) {
    538   // Create a worklist of computations paired with whether the allocation must
    539   // be thread-local.
    540   std::deque<std::pair<const HloComputation*, bool>> worklist;
    541   worklist.push_back(std::make_pair(module->entry_computation(),
    542                                     /*is_thread_local*/ false));
    543 
    544   // Sets for quickly checking membership. Computations are returned in vectors
    545   // for stable iteration.
    546   FlatSet<const HloComputation*> thread_local_set;
    547   FlatSet<const HloComputation*> global_set;
    548 
    549   while (!worklist.empty()) {
    550     auto worklist_front = worklist.front();
    551     worklist.pop_front();
    552     const HloComputation* computation = worklist_front.first;
    553     bool is_thread_local = worklist_front.second;
    554     bool in_thread_local_set = thread_local_set.count(computation) > 0;
    555     bool in_global_set = global_set.count(computation) > 0;
    556 
     557     // If the computation has already been added to the respective set, there
     558     // is nothing to do.
    559     if ((is_thread_local && in_thread_local_set) ||
    560         (!is_thread_local && in_global_set)) {
    561       continue;
    562     }
    563 
     564     // If the computation has already been added to the other set, this is an
     565     // error: the global call to the computation (e.g., while/call) may return
     566     // a reference to one of the thread-local buffers to the calling
     567     // computation, which would become a dangling reference when the
     568     // thread-local buffer is deallocated with the call return.
    569     if ((is_thread_local && in_global_set) ||
    570         (!is_thread_local && in_thread_local_set)) {
    571       return InvalidArgument(
    572           "computation %s has conflicting allocation requirements (global "
    573           "and thread-local)",
    574           computation->name().c_str());
    575     }
    576 
    577     if (is_thread_local) {
    578       thread_local_set.insert(computation);
    579     } else {
    580       global_set.insert(computation);
    581     }
    582 
    583     for (auto* instruction : computation->instructions()) {
    584       for (HloComputation* subcomputation :
    585            instruction->called_computations()) {
    586         switch (instruction->opcode()) {
    587           case HloOpcode::kCall:
    588           case HloOpcode::kConditional:
    589           case HloOpcode::kWhile:
     590             // Call, conditional, and while must be called from a computation
     591             // with global allocations, as they may return references to
     592             // buffers inside the called computation which cannot be thread-local.
    593             if (is_thread_local) {
    594               return InvalidArgument(
    595                   "computation %s cannot contain call/while op because it "
    596                   "requires thread-local buffer allocations",
    597                   computation->name().c_str());
    598             }
    599             worklist.push_back(std::make_pair(subcomputation,
    600                                               false));  // Not thread local.
    601             break;
    602           case HloOpcode::kMap:
    603           case HloOpcode::kReduce:
    604           case HloOpcode::kReduceWindow:
    605           case HloOpcode::kSelectAndScatter:
    606           case HloOpcode::kFusion:
    607             // Map/reduce etc computations are always thread-local.
    608             worklist.push_back(std::make_pair(subcomputation,
    609                                               true));  // Thread local.
    610             break;
    611           default:
    612             return InternalError(
    613                 "Unexpected calling opcode: %s",
    614                 HloOpcodeString(instruction->opcode()).c_str());
    615         }
    616       }
    617     }
    618   }
    619 
    620   // Add the computations to the vectors in post order.
    621   for (auto* computation : module->MakeComputationPostOrder()) {
    622     if (thread_local_set.count(computation) > 0) {
    623       thread_local_computations->push_back(computation);
    624     } else if (global_set.count(computation) > 0) {
    625       global_computations->push_back(computation);
    626     }
    627     // If the computation is not reachable from the entry computation, then it
    628     // will not appear in either thread_local_set or global_set. We don't bother
    629     // assigning buffers for these.
    630   }
    631   return Status::OK();
    632 }
    633 
    634 }  // namespace
    635 
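         // A rough usage sketch (illustrative only: the specific ordering, the
         // size/alignment functions, and the colorer below are placeholders and
         // assumptions, not taken from this file; real call sites are
         // backend-specific):
         //
         //   TF_ASSIGN_OR_RETURN(
         //       std::unique_ptr<BufferAssignment> assignment,
         //       BufferAssigner::Run(module,
         //                           MakeUnique<DependencyHloOrdering>(module),
         //                           /*buffer_size=*/my_size_fn,
         //                           /*color_alignment=*/my_alignment_fn,
         //                           /*allow_input_output_aliasing=*/false,
         //                           BufferLiveness::DefaultColorer()));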
    636 /* static */
    637 StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::Run(
    638     const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
    639     LogicalBuffer::SizeFunction buffer_size,
    640     LogicalBuffer::AlignmentFunction color_alignment,
    641     bool allow_input_output_aliasing, BufferLiveness::Colorer colorer) {
    642   BufferAssigner assigner(allow_input_output_aliasing, std::move(colorer));
    643   return assigner.CreateAssignment(module, std::move(hlo_ordering),
    644                                    std::move(buffer_size),
    645                                    std::move(color_alignment));
    646 }
    647 
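         // Tries to reuse `allocation` for `buffer`. Returns false (leaving the
         // assignment unchanged) if any reuse condition below fails: color or size
         // mismatch, parameter or non-reusable allocation, liveness interference,
         // a copy instruction sharing with its operand, or a live-out buffer whose
         // size does not exactly match the allocation.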
    648 bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
    649                                        const LogicalBuffer& buffer,
    650                                        BufferAssignment* assignment) {
    651   const LogicalBuffer::SizeFunction& buffer_size = assignment->buffer_size_;
    652 
    653   CHECK(!assignment->HasAllocation(buffer))
    654       << "buffer " << buffer << " already has an allocation assigned.";
    655 
    656   VLOG(4) << "Trying to assign " << buffer << " to allocation: " << *allocation;
    657 
    658   if (buffer.color() != allocation->color()) {
     659     VLOG(4) << "Can't assign: buffer has color " << buffer.color()
    660             << " and allocation has color " << allocation->color() << ".";
    661     return false;
    662   }
    663 
    664   if (buffer_size(buffer) > allocation->size()) {
    665     VLOG(4) << "Can't assign: buffer is larger than allocation ("
    666             << buffer_size(buffer) << " > " << allocation->size() << ")";
    667     return false;
    668   }
    669 
    670   if (allocation->is_entry_computation_parameter()) {
    671     VLOG(4) << "Can't assign: allocation holds parameter";
    672     return false;
    673   }
    674 
    675   if (!allocation->is_reusable()) {
    676     VLOG(4) << "Can't assign: allocation is not reusable";
    677     return false;
    678   }
    679 
    680   for (const auto& buffer_offset_size : allocation->assigned_buffers()) {
    681     const LogicalBuffer& assigned_buffer = *buffer_offset_size.first;
    682     if (assignment->liveness().MayInterfere(assigned_buffer, buffer)) {
    683       VLOG(4) << "Can't assign: assignee " << assigned_buffer
    684               << " may interfere with " << buffer;
    685       return false;
    686     }
     687     // Copy instructions don't share a buffer with their input operands.
    688     if (buffer.instruction()->IsUserOf(assigned_buffer.instruction()) &&
    689         buffer.instruction()->opcode() == HloOpcode::kCopy) {
    690       VLOG(4) << "Can't assign: assignee " << assigned_buffer
    691               << " is used at copy instruction " << buffer;
    692       return false;
    693     }
    694   }
    695 
    696   if (allow_input_output_aliasing_ && allocation->maybe_live_out()) {
    697     const HloComputation* entry_computation =
    698         assignment->module_->entry_computation();
    699     for (auto param : entry_computation->parameter_instructions()) {
    700       for (auto& param_buffer :
    701            assignment->points_to_analysis().GetBuffersDefinedByInstruction(
    702                param)) {
    703         if (assignment->liveness().MayInterfere(*param_buffer, buffer)) {
    704           VLOG(4) << "Can't assign: Parameter interference with result";
    705           return false;
    706         }
    707       }
    708     }
    709   }
    710 
     711   // If the buffer is live out of the computation, then it should only be
     712   // assigned an allocation which exactly fits the result, to avoid wasting
     713   // memory (result buffers can have arbitrary lifetimes).
    714   if (assignment->liveness().MaybeLiveOut(buffer) &&
    715       allocation->size() != buffer_size(buffer)) {
    716     VLOG(4) << "Can't assign: buffer " << buffer
     717             << " is live out and size not the same as allocation";
    718     return false;
    719   }
    720 
    721   assignment->AddAssignment(allocation, buffer, /*offset=*/0,
    722                             buffer_size(buffer));
    723   return true;
    724 }
    725 
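         // Assigns allocations to all LogicalBuffers defined by instructions in
         // `computation`. Buffers are processed largest-first (see the sort below);
         // for sequentially ordered computations, temp buffers that cannot reuse an
         // existing allocation are deferred into `buffers_to_assign_sequentially`
         // and later packed by the heap simulator.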
    726 Status BufferAssigner::AssignBuffersForComputation(
    727     const HloComputation* computation, const DebugOptions& debug_options,
    728     bool is_thread_local,
    729     const FlatSet<const LogicalBuffer*>& colocated_buffers,
    730     const FlatSet<BufferAllocation::Index>& colocated_allocations,
    731     FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>*
    732         buffers_to_assign_sequentially,
    733     BufferAssignment* assignment) {
    734   // Buffers are sorted and assigned to BufferAllocations in decreasing order of
    735   // size.
    736   std::vector<const LogicalBuffer*> sorted_buffers;
    737   for (auto* instruction : computation->instructions()) {
     738     // Add all buffers which this instruction defines. Instructions which
     739     // don't define buffers (e.g., bitcast, which just forwards a pointer)
     740     // don't need any allocations.
    741     for (const LogicalBuffer* buffer :
    742          assignment->points_to_analysis().GetBuffersDefinedByInstruction(
    743              instruction)) {
    744       sorted_buffers.push_back(buffer);
    745     }
    746   }
    747 
     748   // Generate a post-order numbering of instructions, used below as a
     749   // tiebreaker when sorting the LogicalBuffers.
    750   FlatMap<const HloInstruction*, int> post_order_position;
    751   int position = 0;
    752   for (auto* instruction : computation->MakeInstructionPostOrder()) {
    753     post_order_position.emplace(instruction, position);
    754     position++;
    755   }
    756 
    757   // If there is a sequential instruction ordering, we'll delay assignment of
    758   // temp buffers until after the main assignment loop.
    759   const BufferLiveness& liveness = assignment->liveness();
    760   const bool has_sequential_order =
    761       liveness.hlo_ordering().SequentialOrder(*computation) != nullptr;
    762   if (has_sequential_order && buffers_to_assign_sequentially != nullptr) {
    763     // Every sequential computation must get an entry in the
    764     // buffers_to_assign_sequentially map, even if we end up with an empty set
    765     // of buffers. This ensures we can correctly determine whether to run
    766     // whole-module heap simulation.
    767     buffers_to_assign_sequentially->emplace(computation,
    768                                             FlatSet<const LogicalBuffer*>());
    769   }
    770 
    771   // Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers
    772   // first for simplicity. This means any previously created BufferAllocation is
    773   // necessarily large enough to hold the output of the current Buffer in
    774   // consideration.
    775   //
     776   // As a secondary sorting criterion, if the instructions are sequentially
    777   // ordered, we assign live-out buffers before others. Note that for sequential
    778   // computations, we'll take temp buffers that can't re-use any allocations and
    779   // assign them via a heap scheduler. By assigning live-out buffers first, we
    780   // increase the odds that temp buffers can re-use an allocation.
    781   //
     782   // As a final tiebreaker, use the post order position of the HLO instruction
     783   // which defines the buffer. This means an instruction will appear after its
    784   // operands (assuming operands are the same/larger size) enabling the
    785   // important reuse case where an elementwise instruction reuses one of its
    786   // operand's buffer. This improves locality.
    787   std::sort(sorted_buffers.begin(), sorted_buffers.end(),
    788             [this, has_sequential_order, &liveness, &post_order_position,
    789              assignment](const LogicalBuffer* a, const LogicalBuffer* b) {
    790               // Primary sort is by decreasing buffer size.
    791               const int64 a_size = assignment->buffer_size_(*a);
    792               const int64 b_size = assignment->buffer_size_(*b);
    793               if (a_size != b_size) {
    794                 return a_size > b_size;  // use ">" for decreasing size.
    795               }
    796               // Otherwise live out buffers come before others, if the
    797               // instructions are sequentially ordered.
    798               if (has_sequential_order) {
    799                 const bool a_live_out = liveness.MaybeLiveOut(*a);
    800                 const bool b_live_out = liveness.MaybeLiveOut(*b);
    801                 if (a_live_out != b_live_out) {
    802                   return a_live_out;
    803                 }
    804               }
    805               // Final tiebreaker is in instruction post order.
    806               return post_order_position.at(a->instruction()) <
    807                      post_order_position.at(b->instruction());
    808             });
    809 
    810   // BufferAllocations are necessarily created in decreasing size order. Keep
    811   // indices of previously created BufferAllocations in allocation_indices.
    812   std::vector<BufferAllocation::Index> allocation_indices;
    813   for (const LogicalBuffer* buffer : sorted_buffers) {
    814     VLOG(3) << "Assigning allocation to: " << *buffer;
    815     if (colocated_buffers.count(buffer) > 0) {
    816       // Colocated buffers are currently assigned in an earlier pass.
    817       VLOG(3) << "Skipping colocated buffer: " << *buffer;
    818       continue;
    819     }
    820 
    821     TF_RET_CHECK(!assignment->HasAllocation(*buffer));
    822 
    823     const HloInstruction* instruction = buffer->instruction();
    824     if (instruction->opcode() == HloOpcode::kConstant) {
    825       // No BufferAllocations for constants.
    826       // TODO(b/32248867): For consistency, constants should get allocations.
    827       VLOG(3) << "Skipping constant: " << *buffer;
    828       continue;
    829     }
    830 
    831     const int64 buffer_size = assignment->buffer_size_(*buffer);
    832 
    833     const bool is_entry_parameter =
    834         instruction->opcode() == HloOpcode::kParameter &&
    835         computation == computation->parent()->entry_computation();
    836     if (is_entry_parameter) {
     837       // If the LogicalBuffer is part of an external parameter, create a new
     838       // allocation and set its parameter number. Parameters of non-entry
    839       // computations do not need special allocations because they live inside
    840       // callers.
    841       BufferAllocation* allocation =
    842           assignment->NewAllocation(*buffer, buffer_size,
    843                                     /*is_thread_local=*/false,
    844                                     /*is_reusable=*/false);
    845       allocation->set_entry_computation_parameter(
    846           instruction->parameter_number(), buffer->index());
    847       VLOG(3) << "New allocation #" << allocation->index()
    848               << " for entry computation parameter: " << *buffer;
    849       continue;
    850     }
    851 
    852     if (is_thread_local) {
    853       // We do not reuse thread-local buffers for now, because they are
    854       // dynamically allocated and their lifetimes are hard to compute.
    855       BufferAllocation* allocation = assignment->NewAllocation(
    856           *buffer, buffer_size, is_thread_local, /*is_reusable=*/false);
    857       VLOG(3) << "New allocation #" << allocation->index()
    858               << " for thread-local: " << *buffer;
    859       continue;
    860     }
    861 
    862     if (ShapeUtil::IsTuple(buffer->shape())) {
    863       // TODO(b/34669761): Don't reuse tuple buffers because the GPU backend
    864       // assumes longer buffer liveness than indicated by the analysis.
    865       BufferAllocation* allocation = assignment->NewAllocation(
    866           *buffer, buffer_size, is_thread_local, /*is_reusable=*/false);
    867       VLOG(3) << "New allocation #" << allocation->index()
    868               << " for tuple-shaped buffer: " << *buffer;
    869       continue;
    870     }
    871 
    872     // First try to assign a LogicalBuffer to one of its operand allocations to
    873     // improve locality. This is only possible with elementwise operations
     874     // (checked in liveness analysis), whose outputs are necessarily top-level
    875     // array-shaped buffers.
    876     if (buffer->IsTopLevel() && !buffer->IsTuple()) {
    877       for (auto* operand : instruction->operands()) {
    878         bool assigned_operand = false;
    879         for (const auto& operand_slice :
    880              assignment->GetAllSlices(operand, /*index=*/{})) {
    881           BufferAllocation* allocation =
    882               assignment->GetMutableAllocation(operand_slice.index());
    883           if (colocated_allocations.count(allocation->index()) == 0) {
    884             // TODO(b/32491382) Colocated buffers are currently assigned in an
    885             // earlier pass, and so can break the "increasing allocation size"
    886             // invariant in this function (causing this CHECK to fail). However,
    887             // the call to MaybeAssignBuffer is safe as it returns false if
    888             // allocation.size < buffer.size.
    889             CHECK_GE(allocation->size(), buffer_size);
    890           }
    891           if (MaybeAssignBuffer(allocation, *buffer, assignment)) {
    892             VLOG(3) << "Reusing (operand) allocation #" << allocation->index()
    893                     << " for: " << *buffer;
    894             assigned_operand = true;
    895             break;
    896           }
    897         }
    898         if (assigned_operand) {
    899           break;
    900         }
    901       }
    902     }
    903 
    904     if (!assignment->HasAllocation(*buffer)) {
     905       // Find the smallest allocation which can be reused, iterating from the
     906       // end of allocation_indices (smallest) to the beginning (largest).
    907       for (int allocation_index = allocation_indices.size() - 1;
    908            allocation_index >= 0; allocation_index--) {
    909         BufferAllocation* allocation = assignment->GetMutableAllocation(
    910             allocation_indices[allocation_index]);
     911         // Buffers are iterated in decreasing size order, so any previously
     912         // created allocation must be large enough to hold this
    913         // instruction's output (with the exception of colocated buffers).
    914         if (colocated_allocations.count(allocation->index()) == 0) {
    915           // TODO(b/32491382) Colocated buffers are currently assigned in an
    916           // earlier pass, and so can break the "increasing allocation size"
    917           // invariant in this function (causing this CHECK to fail). However,
    918           // the call to MaybeAssignBuffer is safe as it returns false if
    919           // allocation.size < buffer.size.
    920           CHECK_GE(allocation->size(), buffer_size);
    921         }
    922 
    923         if (MaybeAssignBuffer(allocation, *buffer, assignment)) {
    924           VLOG(3) << "Reusing allocation #" << allocation->index()
    925                   << " for: " << *buffer;
    926           break;
    927         }
    928       }
    929     }
    930 
    931     if (!assignment->HasAllocation(*buffer) && has_sequential_order &&
    932         !liveness.MaybeLiveOut(*buffer)) {
    933       // There is a sequential instruction ordering, so we delay assignment of
    934       // temp buffers until after the loop. We do this right before we decide to
    935       // create a new allocation, to ensure we've exhausted all the buffer
    936       // re-use cases above.
    937       //
    938       // Entry parameters and thread local buffers were already handled earlier
    939       // in this loop iteration.  See BufferAllocation::IsPreallocatedTempBuffer
    940       // for the definition of temp buffers.
    941       CHECK(!is_entry_parameter) << *buffer;
    942       CHECK(!is_thread_local) << *buffer;
    943       (*buffers_to_assign_sequentially)[computation].insert(buffer);
    944       VLOG(3) << "Delaying assignment of temp buffer: " << *buffer;
    945       continue;
    946     }
    947 
    948     if (!assignment->HasAllocation(*buffer)) {
    949       BufferAllocation* allocation = assignment->NewAllocation(
    950           *buffer, buffer_size, is_thread_local, /*is_reusable=*/true);
    951       allocation_indices.push_back(allocation->index());
    952       VLOG(3) << "New allocation #" << allocation->index()
    953               << " for: " << *buffer;
    954     }
    955   }
    956 
    957   return Status::OK();
    958 }
    959 
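         // Groups `buffers` by color so that each color can be heap-simulated and
         // allocated independently.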
    960 FlatMap<LogicalBuffer::Color, FlatSet<const LogicalBuffer*>,
    961         LogicalBuffer::Color::Hasher>
    962 BufferAssigner::SplitBuffersByColor(
    963     const FlatSet<const LogicalBuffer*>& buffers) {
    964   FlatMap<LogicalBuffer::Color, FlatSet<const LogicalBuffer*>,
    965           LogicalBuffer::Color::Hasher>
    966       color_map;
    967   for (auto buffer : buffers) {
    968     color_map[buffer->color()].insert(buffer);
    969   }
    970   return color_map;
    971 }
    972 
    973 Status BufferAssigner::AssignBuffersWithSequentialOrdering(
    974     const FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>&
    975         buffers_to_assign_sequentially,
    976     bool run_whole_module_heap_simulation, BufferAssignment* assignment) {
    977   // Run the sequence of instructions through the heap simulator.  The heuristic
    978   // that seems to give the best results is lazy-best-fit, with all runs of
    979   // alloc / free calls sorted in decreasing size order.
    980   const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering();
    981   if (run_whole_module_heap_simulation) {
    982     // Run the heap simulation over the whole module. This reduces memory usage,
    983     // since buffers for kCall, kWhile, and kConditional sub-computations are
    984     // only live for the duration of their calling instructions.
    985     VLOG(1) << "Running whole-module heap simulation";
    986     SequentialHloOrdering::HloModuleSequence module_sequence;
    987     FlatSet<const LogicalBuffer*> all_buffers_to_assign;
    988     for (const auto& pair : buffers_to_assign_sequentially) {
    989       const HloComputation* computation = pair.first;
    990       const FlatSet<const LogicalBuffer*>& buffers_to_assign = pair.second;
    991       const std::vector<const HloInstruction*>* instruction_sequence =
    992           hlo_ordering.SequentialOrder(*computation);
    993       CHECK(instruction_sequence != nullptr) << computation->name();
    994       module_sequence[computation] = *instruction_sequence;
    995       all_buffers_to_assign.insert(buffers_to_assign.begin(),
    996                                    buffers_to_assign.end());
    997     }
    998     auto color_map = SplitBuffersByColor(all_buffers_to_assign);
    999     for (auto& single_colored_set : color_map) {
   1000       auto color = single_colored_set.first;
   1001       VLOG(2) << "Simulating heap for color " << color;
   1002       int64 alignment = assignment->color_alignment_(color);
   1003       HeapSimulator::Options options;
   1004       options.buffers_to_assign = &single_colored_set.second;
   1005       TF_ASSIGN_OR_RETURN(
   1006           const HeapSimulator::Result result,
   1007           HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>(
   1008                                  MakeUnique<LazyBestFitHeap>(alignment)),
   1009                              assignment->module(), module_sequence,
   1010                              assignment->points_to_analysis(),
   1011                              assignment->buffer_size_, options));
   1012       AssignBuffersFromHeapSimulator(result, assignment,
   1013                                      single_colored_set.first);
   1014     }
   1015   } else {
   1016     // Run the heap-simulation on a per-computation basis. Buffers for
   1017     // sub-computations are assigned disjoint BufferAllocations, assuming the
   1018     // worst-case that they may all be live concurrently.
   1019     VLOG(1) << "Running per-computation heap simulation";
   1020     for (const auto& pair : buffers_to_assign_sequentially) {
   1021       const HloComputation* computation = pair.first;
   1022       const FlatSet<const LogicalBuffer*>& buffers_to_assign = pair.second;
   1023       const std::vector<const HloInstruction*>* instruction_sequence =
   1024           hlo_ordering.SequentialOrder(*computation);
   1025       CHECK(instruction_sequence != nullptr) << computation->name();
   1026       auto color_map = SplitBuffersByColor(buffers_to_assign);
   1027       for (auto& single_colored_set : color_map) {
   1028         auto color = single_colored_set.first;
   1029         VLOG(2) << "Simulating heap for color " << color;
   1030         int64 alignment = assignment->color_alignment_(color);
   1031         HeapSimulator::Options options;
   1032         options.buffers_to_assign = &single_colored_set.second;
   1033         TF_ASSIGN_OR_RETURN(
   1034             const HeapSimulator::Result result,
   1035             HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>(
   1036                                    MakeUnique<LazyBestFitHeap>(alignment)),
   1037                                *computation, *instruction_sequence,
   1038                                assignment->points_to_analysis(),
   1039                                assignment->buffer_size_, options));
   1040         AssignBuffersFromHeapSimulator(result, assignment,
   1041                                        single_colored_set.first);
   1042       }
   1043     }
   1044   }
   1045   return Status::OK();
   1046 }
   1047 
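         // Materializes one heap-simulator result as a single reusable allocation
         // of `result.heap_size` bytes, recording each buffer's chunk as an
         // offset/size assignment within it, and accumulates the fragmentation
         // stats and the debug trace.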
   1048 void BufferAssigner::AssignBuffersFromHeapSimulator(
   1049     const HeapSimulator::Result& result, BufferAssignment* assignment,
   1050     LogicalBuffer::Color color) {
   1051   if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) {
   1052     assignment->stats_.preallocated_temp_fragmentation_bytes =
   1053         result.fragmentation_size;
   1054   } else {
   1055     assignment->stats_.preallocated_temp_fragmentation_bytes +=
   1056         result.fragmentation_size;
   1057   }
   1058 
   1059   BufferAllocation* allocation = assignment->NewEmptyAllocation(
   1060       result.heap_size, /*is_thread_local=*/false, /*is_reusable=*/true, color);
   1061   for (const auto& buffer_chunk : result.chunk_map) {
   1062     const LogicalBuffer& buffer = *buffer_chunk.first;
   1063     const HeapSimulator::Chunk& chunk = buffer_chunk.second;
   1064     assignment->AddAssignment(allocation, buffer, chunk.offset, chunk.size);
   1065   }
   1066 
   1067   assignment->heap_simulator_traces_.push_back(result.debug_trace);
   1068 }
   1069 
   1070 // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining
   1071 // the invariant that all sets in 'colocated_buffer_sets' are disjoint.
   1072 //
   1073 // A practical example of when this is necessary is a chain of kCall ops:
   1074 //   computation.entry
   1075 //     %a = call() -> computation.1
   1076 //   computation.1
   1077 //     %b = call() -> computation.2
   1078 //   computation.2
   1079 //     %c = parameter()
   1080 // This yields the logical sets {%a,%b} {%b,%c} {%c}, which need to be merged
    1081 // into a single set {%a,%b,%c}.
   1082 void BufferAssigner::AddSetToColocatedBufferSets(
   1083     const std::vector<const LogicalBuffer*>& colocated_set,
   1084     std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
   1085   if (colocated_set.empty()) {
   1086     return;
   1087   }
   1088 
   1089   // Find existing sets that overlap with at least one buffer from the
   1090   // colocated_set. The resulting 'overlap_set_indices' will have at most
   1091   // colocated_buffer_sets->size() entries, and will be in increasing order.
   1092   std::vector<size_t> overlap_set_indices;
   1093   for (size_t index = 0; index < colocated_buffer_sets->size(); ++index) {
   1094     for (const LogicalBuffer* buffer : colocated_set) {
   1095       if ((*colocated_buffer_sets)[index].count(buffer) > 0) {
   1096         overlap_set_indices.push_back(index);
   1097         break;
   1098       }
   1099     }
   1100   }
   1101 
   1102   // If there is no overlap with existing sets, create a new set.
   1103   if (overlap_set_indices.empty()) {
   1104     colocated_buffer_sets->emplace_back();
   1105     colocated_buffer_sets->back().insert(colocated_set.begin(),
   1106                                          colocated_set.end());
   1107     return;
   1108   }
   1109 
   1110   // Merge all overlap sets and the colocated set into the first overlap set.
   1111   ColocatedBufferSet* first = &(*colocated_buffer_sets)[overlap_set_indices[0]];
   1112   for (size_t index = 1; index < overlap_set_indices.size(); ++index) {
   1113     const ColocatedBufferSet& overlap_set =
   1114         (*colocated_buffer_sets)[overlap_set_indices[index]];
   1115     first->insert(overlap_set.begin(), overlap_set.end());
   1116   }
   1117   first->insert(colocated_set.begin(), colocated_set.end());
   1118 
   1119   // Remove overlap sets that we just merged. The offset accounts for the fact
   1120   // that as elements are erased, the indices need to be adjusted. Keep in mind
   1121   // that overlap_set_indices is in increasing order.
   1122   for (size_t index = 1; index < overlap_set_indices.size(); ++index) {
   1123     const size_t offset = overlap_set_indices[index] - index + 1;
   1124     colocated_buffer_sets->erase(colocated_buffer_sets->begin() + offset);
   1125   }
   1126 }
   1127 
   1128 namespace {
   1129 
   1130 // Checks that points-to set of 'instruction' is unambiguous and distinct
   1131 // (ensured by CopyInsertion), then adds the buffer from the points-to set at
   1132 // 'index' to 'colocated_set'.
   1133 const LogicalBuffer* AddBufferToColocatedSet(
   1134     const HloInstruction* instruction, const ShapeIndex& index,
   1135     const TuplePointsToAnalysis& points_to_analysis,
   1136     std::vector<const LogicalBuffer*>* colocated_set) {
   1137   // CopyInsertion ensures root points-to set is unambiguous and distinct.
   1138   const auto& points_to = points_to_analysis.GetPointsToSet(instruction);
   1139   DCHECK(!points_to.IsAmbiguous());
   1140   colocated_set->push_back(points_to.element(index)[0]);
   1141   return colocated_set->back();
   1142 }
   1143 
   1144 // Given the interference map of a graph (the list of interfering node indices
   1145 // for each node), perform graph coloring such that interfering nodes are
   1146 // assigned to different colors. Returns the assigned color of the nodes, where
   1147 // the colors are represented as integer values [0, color_count).
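         //
         // For example, the interference map {{1, 2}, {0}, {0}} (node 0 interferes
         // with nodes 1 and 2) yields the coloring {0, 1, 1}: node 0 is assigned
         // first (most constrained) and gets color 0, so nodes 1 and 2 each take
         // the lowest remaining color, 1.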
   1148 std::vector<int64> ColorInterferenceGraph(
   1149     const std::vector<std::vector<int64>>& interference_map) {
   1150   const int64 node_count = interference_map.size();
   1151 
   1152   // Sort the nodes such that we assign nodes with more interference first. This
   1153   // relies on the common heuristic of assigning the most constrained node
   1154   // first, but it would be good to investigate other ordering heuristics too.
   1155   std::vector<int64> nodes(node_count);
   1156   std::iota(nodes.begin(), nodes.end(), 0);
   1157   std::sort(nodes.begin(), nodes.end(),
   1158             [&interference_map](const int64 i, const int64 j) {
   1159               return interference_map[i].size() > interference_map[j].size();
   1160             });
   1161 
   1162   const int64 kColorUnassigned = -1;
   1163   std::vector<int64> assigned_colors(node_count, kColorUnassigned);
   1164   for (int64 node : nodes) {
   1165     // Mark the colors that are already assigned to the neighbors.
   1166     std::vector<bool> available_colors(node_count, true);
   1167     for (int64 neighbor : interference_map[node]) {
   1168       int64 color = assigned_colors[neighbor];
   1169       if (color != kColorUnassigned) {
   1170         available_colors[color] = false;
   1171       }
   1172     }
   1173 
    1174     // Find the first color not yet assigned to any neighbor; one always
             // exists because a node has at most node_count - 1 neighbors.
   1175     int64 color = kColorUnassigned;
   1176     for (color = 0; color < available_colors.size(); ++color) {
   1177       if (available_colors[color]) {
   1178         break;
   1179       }
   1180     }
    1181     CHECK_LT(color, node_count);
   1182     assigned_colors[node] = color;
   1183   }
   1184   return assigned_colors;
   1185 }
   1186 
   1187 }  // namespace
   1188 
   1189 std::vector<BufferAssigner::ColocatedBufferSet>
   1190 BufferAssigner::MergeColocatedBufferSets(
   1191     const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
   1192     const BufferLiveness& buffer_liveness,
   1193     const LogicalBuffer::SizeFunction& buffer_size) {
   1194   VLOG(1) << "colocation sets count before coalescing:"
   1195           << colocated_buffer_sets.size();
   1196 
   1197   // Returns true if the given buffer is for the entry parameter.
   1198   auto is_entry_parameter = [](const LogicalBuffer& buffer) {
   1199     auto* instruction = buffer.instruction();
   1200     auto* computation = instruction->parent();
   1201     auto* module = computation->parent();
   1202     return instruction->opcode() == HloOpcode::kParameter &&
   1203            computation == module->entry_computation();
   1204   };
   1205 
   1206   // Returns true if the two colocated buffer sets (specified by their indices
   1207   // into the colocated_buffer_sets) can be merged into a single set.
   1208   auto cannot_merge_buffer_sets = [&colocated_buffer_sets, &buffer_liveness,
   1209                                    &buffer_size,
   1210                                    &is_entry_parameter](int64 i, int64 j) {
   1211     for (auto& buffer_a : colocated_buffer_sets[i]) {
   1212       for (auto& buffer_b : colocated_buffer_sets[j]) {
   1213         // Do not merge if the set includes live outs or entry parameters.
   1214         if ((buffer_liveness.MaybeLiveOut(*buffer_a) &&
   1215              is_entry_parameter(*buffer_b)) ||
   1216             (buffer_liveness.MaybeLiveOut(*buffer_b) &&
   1217              is_entry_parameter(*buffer_a))) {
   1218           return true;
   1219         }
   1220         // Do not merge if the buffers interfere with each other.
   1221         if (buffer_a->id() != buffer_b->id() &&
   1222             buffer_liveness.MayInterfere(*buffer_a, *buffer_b)) {
   1223           return true;
   1224         }
   1225         // Do not merge if the buffer sizes are different.
   1226         if (buffer_size(*buffer_a) != buffer_size(*buffer_b)) {
   1227           return true;
   1228         }
   1229       }
   1230     }
   1231     return false;
   1232   };
   1233 
   1234   // Build the interference map among the colocated buffer sets (nodes), by
   1235   // adding an edge between any two nodes that cannot be merged into a single
   1236   // colocated buffer set.
   1237   std::vector<std::vector<int64>> interference_map(
   1238       colocated_buffer_sets.size());
   1239   for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) {
   1240     for (int64 j = i + 1; j < colocated_buffer_sets.size(); ++j) {
   1241       if (cannot_merge_buffer_sets(i, j)) {
   1242         interference_map[i].push_back(j);
   1243         interference_map[j].push_back(i);
   1244       }
   1245     }
   1246   }
   1247 
    1248   // Assign a color to each colocated buffer set, such that any two sets
    1249   // assigned the same color do not interfere and can therefore be merged.
   1250   auto assigned_colors = ColorInterferenceGraph(interference_map);
   1251 
   1252   // Merge the buffer sets with the same color.
   1253   CHECK(!assigned_colors.empty());
   1254   int64 num_sets =
   1255       *std::max_element(assigned_colors.begin(), assigned_colors.end()) + 1;
   1256   std::vector<ColocatedBufferSet> new_colocated_buffer_sets(num_sets);
   1257   for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) {
   1258     const auto& buffer_set = colocated_buffer_sets[i];
   1259     new_colocated_buffer_sets[assigned_colors[i]].insert(buffer_set.begin(),
   1260                                                          buffer_set.end());
   1261   }
   1262 
    1263   VLOG(1) << "colocation sets count after coalescing:"
    1264           << new_colocated_buffer_sets.size();
   1265   return new_colocated_buffer_sets;
   1266 }
   1267 
   1268 // Builds sets of buffers in 'colocated_buffer_sets' which should be colocated
   1269 // in the same allocation (currently just supports kWhile, kCall, and
   1270 // kConditional).
   1271 void BufferAssigner::BuildColocatedBufferSets(
   1272     const HloModule* module, const BufferLiveness& buffer_liveness,
   1273     const LogicalBuffer::SizeFunction& buffer_size,
   1274     std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
   1275   const TuplePointsToAnalysis& points_to_analysis =
   1276       buffer_liveness.points_to_analysis();
   1277   for (const HloComputation* computation : module->MakeComputationPostOrder()) {
   1278     if (computation->IsFusionComputation()) {
   1279       continue;
   1280     }
   1281     for (const HloInstruction* instruction :
   1282          computation->MakeInstructionPostOrder()) {
   1283       const HloOpcode opcode = instruction->opcode();
   1284       if (opcode == HloOpcode::kWhile) {
   1285         const HloInstruction* while_hlo = instruction;
   1286         ShapeUtil::ForEachSubshape(
   1287             while_hlo->shape(),
   1288             [this, while_hlo, &points_to_analysis, &buffer_liveness,
   1289              buffer_size, computation, colocated_buffer_sets](
   1290                 const Shape& /*subshape*/, const ShapeIndex& index) {
   1291               std::vector<const LogicalBuffer*> colocated_set;
   1292               // Add while.init.
   1293               AddBufferToColocatedSet(while_hlo->operand(0), index,
   1294                                       points_to_analysis, &colocated_set);
   1295               // Add while.result.
   1296               AddBufferToColocatedSet(while_hlo, index, points_to_analysis,
   1297                                       &colocated_set);
   1298               // Add while.cond.parameter.
   1299               AddBufferToColocatedSet(
   1300                   while_hlo->while_condition()->parameter_instruction(0), index,
   1301                   points_to_analysis, &colocated_set);
   1302               // Add while.body.parameter.
   1303               AddBufferToColocatedSet(
   1304                   while_hlo->while_body()->parameter_instruction(0), index,
   1305                   points_to_analysis, &colocated_set);
   1306               // Add while.body.root.
   1307               AddBufferToColocatedSet(
   1308                   while_hlo->while_body()->root_instruction(), index,
   1309                   points_to_analysis, &colocated_set);
   1310               AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
   1311             });
   1312       } else if (opcode == HloOpcode::kCall) {
   1313         const HloInstruction* call_hlo = instruction;
   1314         const HloInstruction* root_hlo =
   1315             call_hlo->to_apply()->root_instruction();
   1316         ShapeUtil::ForEachSubshape(
   1317             call_hlo->shape(),
   1318             [this, call_hlo, root_hlo, &points_to_analysis,
   1319              colocated_buffer_sets](const Shape& /*subshape*/,
   1320                                     const ShapeIndex& index) {
   1321               std::vector<const LogicalBuffer*> colocated_set;
   1322               // Add call.result.
   1323               AddBufferToColocatedSet(call_hlo, index, points_to_analysis,
   1324                                       &colocated_set);
   1325               // Add call.subcomputation.root.
   1326               AddBufferToColocatedSet(root_hlo, index, points_to_analysis,
   1327                                       &colocated_set);
   1328               AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
   1329             });
   1330       } else if (opcode == HloOpcode::kConditional) {
   1331         const HloInstruction* conditional_hlo = instruction;
   1332         ShapeUtil::ForEachSubshape(
   1333             conditional_hlo->shape(),
   1334             [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets](
   1335                 const Shape& /*subshape*/, const ShapeIndex& index) {
   1336               std::vector<const LogicalBuffer*> colocated_set;
   1337               // Add conditional.result.
   1338               AddBufferToColocatedSet(conditional_hlo, index,
   1339                                       points_to_analysis, &colocated_set);
   1340               // Add conditional.true_computation.root.
   1341               AddBufferToColocatedSet(
   1342                   conditional_hlo->true_computation()->root_instruction(),
   1343                   index, points_to_analysis, &colocated_set);
   1344               // Add conditional.false_computation.root.
   1345               AddBufferToColocatedSet(
   1346                   conditional_hlo->false_computation()->root_instruction(),
   1347                   index, points_to_analysis, &colocated_set);
   1348               AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
   1349             });
   1350 
   1351         // Add true_operand and conditional.true_computation.parameter(0) as a
   1352         // colocated buffer set. Note that this has to be done for each subshape
   1353         // in the true_operand of the conditional.
   1354         ShapeUtil::ForEachSubshape(
   1355             conditional_hlo->operand(1)->shape(),
   1356             [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets](
   1357                 const Shape& /*subshape*/, const ShapeIndex& index) {
   1358               std::vector<const LogicalBuffer*> true_set;
   1359               // Add conditional.true_operand.
   1360               AddBufferToColocatedSet(conditional_hlo->operand(1), index,
   1361                                       points_to_analysis, &true_set);
   1362               // Add conditional.true_computation.parameter_instruction(0).
   1363               AddBufferToColocatedSet(
   1364                   conditional_hlo->true_computation()->parameter_instruction(0),
   1365                   index, points_to_analysis, &true_set);
   1366               AddSetToColocatedBufferSets(true_set, colocated_buffer_sets);
   1367             });
   1368 
   1369         // Add false_operand and conditional.false_computation.parameter(0) as a
   1370         // colocated buffer set. Note that this has to be done for each subshape
   1371         // in the false_operand of the conditional.
   1372         ShapeUtil::ForEachSubshape(
   1373             conditional_hlo->operand(2)->shape(),
   1374             [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets](
   1375                 const Shape& /*subshape*/, const ShapeIndex& index) {
   1376               std::vector<const LogicalBuffer*> false_set;
   1377               // Add conditional.false_operand.
   1378               AddBufferToColocatedSet(conditional_hlo->operand(2), index,
   1379                                       points_to_analysis, &false_set);
   1380               // Add conditional.false_computation.parameter_instruction(0).
   1381               AddBufferToColocatedSet(
   1382                   conditional_hlo->false_computation()->parameter_instruction(
   1383                       0),
   1384                   index, points_to_analysis, &false_set);
   1385               AddSetToColocatedBufferSets(false_set, colocated_buffer_sets);
   1386             });
   1387       }
   1388     }
   1389   }
   1390 
   1391   if (colocated_buffer_sets->empty()) {
   1392     return;
   1393   }
   1394 
   1395   // Try to find more coalescing opportunities among the colocated buffer sets.
   1396   //
   1397   // TODO(b/32491382): We should be able to remove this by using the
   1398   // module-level liveness analysis, which would let us directly detect buffer
   1399   // sharing opportunities between the while instruction buffer and the buffers
   1400   // from the predicate and body computation, as well as sharing across
   1401   // different while instructions.
   1402   std::vector<ColocatedBufferSet> new_colocated_buffer_sets =
   1403       MergeColocatedBufferSets(*colocated_buffer_sets, buffer_liveness,
   1404                                buffer_size);
   1405   std::swap(*colocated_buffer_sets, new_colocated_buffer_sets);
   1406 }
   1407 
   1408 // Assigns all colocated buffer sets in 'colocated_buffer_sets' to the same
   1409 // allocation in 'assignment'.
   1410 void BufferAssigner::AssignColocatedBufferSets(
   1411     const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
   1412     BufferAssignment* assignment,
   1413     FlatSet<const LogicalBuffer*>* colocated_buffers,
   1414     FlatSet<BufferAllocation::Index>* colocated_allocations) {
   1415   for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) {
   1416     BufferAllocation* allocation = nullptr;
    1417     // Set 'entry_parameter_number' and 'entry_parameter_shape_idx' if an
    1418     // entry parameter is present in 'colocated_buffer_set'.
   1419     int64 entry_parameter_number = -1;
   1420     const ShapeIndex* entry_parameter_shape_idx = nullptr;
   1421     for (const LogicalBuffer* buffer : colocated_buffer_set) {
   1422       const HloInstruction* instruction = buffer->instruction();
   1423       const HloComputation* computation = instruction->parent();
   1424       if (instruction->opcode() == HloOpcode::kParameter &&
   1425           computation == computation->parent()->entry_computation()) {
   1426         entry_parameter_number = instruction->parameter_number();
   1427         entry_parameter_shape_idx = &buffer->index();
   1428         break;
   1429       }
   1430     }
   1431 
   1432     for (const LogicalBuffer* buffer : colocated_buffer_set) {
   1433       const int64 buffer_size = assignment->buffer_size_(*buffer);
   1434       if (allocation == nullptr) {
   1435         // TODO(b/32491382) Avoid current trivial solution of using new
   1436         // allocations for each colocated buffer set. When liveness has
   1437         // module-level scope, we can allow buffers to be shared across
   1438         // computations (in some cases).
   1439         allocation = assignment->NewAllocation(*buffer, buffer_size,
   1440                                                /*is_thread_local=*/false,
   1441                                                /*is_reusable=*/true);
   1442         if (entry_parameter_number >= 0) {
   1443           // This colocated buffer set contains an entry parameter and other
   1444           // logical buffers which use the parameter as read-only in a while
   1445           // body computation (which updates in place).
   1446           // Set 'entry_computation_parameter' to indicate that it contains
   1447           // an entry parameter, and to prevent reuse in MaybeAssignBuffer.
   1448           allocation->set_entry_computation_parameter(
   1449               entry_parameter_number, *entry_parameter_shape_idx);
   1450         }
   1451         colocated_allocations->insert(allocation->index());
   1452       } else {
   1453         CHECK_EQ(buffer_size, allocation->size())
   1454             << "Buffer: " << *buffer << " size mismatch in colocated buffer "
   1455             << "allocation: " << *allocation;
   1456         assignment->AddAssignment(allocation, *buffer, /*offset=*/0,
   1457                                   buffer_size);
   1458       }
   1459       colocated_buffers->insert(buffer);
   1460     }
   1461   }
   1462 }
   1463 
   1464 StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
   1465     const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
   1466     LogicalBuffer::SizeFunction buffer_size,
   1467     LogicalBuffer::AlignmentFunction color_alignment) {
   1468   TF_ASSIGN_OR_RETURN(std::unique_ptr<BufferLiveness> liveness,
   1469                       BufferLiveness::Run(module, std::move(hlo_ordering)));
   1470 
   1471   VLOG(1) << "Assigning buffers to module " << module->name();
   1472   XLA_VLOG_LINES(2, module->ToString());
   1473   XLA_VLOG_LINES(3, liveness->ToString());
   1474   XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString());
   1475 
   1476   // Can't use MakeUnique because BufferAssignment constructor is private.
   1477   std::unique_ptr<BufferAssignment> assignment(
   1478       new BufferAssignment(module, std::move(liveness), std::move(buffer_size),
   1479                            std::move(color_alignment)));
   1480 
   1481   // Assign buffers with the tightest constraints first (colocated buffer sets).
   1482   // Once b/32491382 enables module-level liveness analysis, we may be able
   1483   // to assign colocated buffers (or at least reuse their allocation for
   1484   // buffers outside of the set) in AssignBuffersForComputation.
   1485   FlatSet<const LogicalBuffer*> colocated_buffers;
   1486   FlatSet<BufferAllocation::Index> colocated_allocations;
   1487   std::vector<ColocatedBufferSet> colocated_buffer_sets;
   1488   BuildColocatedBufferSets(module, assignment->liveness(),
   1489                            assignment->buffer_size_, &colocated_buffer_sets);
   1490   TF_RETURN_IF_ERROR(colorer_(assignment->liveness()));
   1491   VLOG(3) << "After coloring:";
   1492   XLA_VLOG_LINES(3, assignment->points_to_analysis().ToString());
   1493 
   1494   AssignColocatedBufferSets(colocated_buffer_sets, assignment.get(),
   1495                             &colocated_buffers, &colocated_allocations);
   1496 
   1497   std::vector<const HloComputation*> thread_local_computations;
   1498   std::vector<const HloComputation*> global_computations;
   1499   TF_RETURN_IF_ERROR(GatherComputationsByAllocationType(
   1500       module, &thread_local_computations, &global_computations));
   1501 
    1502   // First assign buffers for global computations. Temporary buffers for
   1503   // sequential computations are collected in 'buffers_to_assign_sequentially'.
   1504   FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>
   1505       buffers_to_assign_sequentially;
   1506   for (auto* computation : global_computations) {
   1507     TF_RETURN_IF_ERROR(AssignBuffersForComputation(
   1508         computation, module->config().debug_options(),
   1509         /*is_thread_local=*/false, colocated_buffers, colocated_allocations,
   1510         &buffers_to_assign_sequentially, assignment.get()));
   1511   }
   1512   // Assign buffers with sequential ordering, if any. If all global computations
    1513   // are sequential, we can run heap simulation on the whole module, which
   1514   // reduces memory usage.
   1515   const bool run_whole_module_heap_simulation =
   1516       buffers_to_assign_sequentially.size() == global_computations.size();
   1517   TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering(
   1518       buffers_to_assign_sequentially, run_whole_module_heap_simulation,
   1519       assignment.get()));
   1520 
   1521   // Now assign buffers for thread-local computations. All LogicalBuffers get
   1522   // their own BufferAllocation.
   1523   for (auto* computation : thread_local_computations) {
   1524     TF_RET_CHECK(computation != module->entry_computation());
   1525     if (computation->IsFusionComputation()) {
   1526       continue;
   1527     }
   1528     TF_RETURN_IF_ERROR(AssignBuffersForComputation(
   1529         computation, module->config().debug_options(),
   1530         /*is_thread_local=*/true, colocated_buffers, colocated_allocations,
   1531         /*buffers_to_assign_sequentially=*/nullptr, assignment.get()));
   1532   }
   1533 
   1534   // Mark all buffers which may be live out of the entry computation as
   1535   // "liveout".
   1536   for (const LogicalBuffer* buffer :
   1537        assignment->liveness().maybe_live_out_buffers()) {
   1538     VLOG(3) << "maybe_live_out LogicalBuffer: " << *buffer;
   1539     if (assignment->HasAllocation(*buffer)) {
   1540       BufferAllocation* alloc =
   1541           assignment->GetMutableAssignedAllocation(*buffer);
   1542       alloc->set_maybe_live_out(true);
   1543       VLOG(3) << "maybe_live_out BufferAllocation: " << *alloc;
   1544     }
   1545   }
   1546 
   1547   // Combines allocations of temporary buffers into one big BufferAllocation.
   1548   // This can only be performed after all buffers have been assigned, and after
   1549   // maybe_live_out is marked, since it is used to determine whether an
   1550   // allocation contains temporary buffers or not.
   1551   assignment->CombineTempAllocations();
   1552 
   1553   XLA_VLOG_LINES(2, assignment->ToString());
   1554   TF_RETURN_IF_ERROR(assignment->ComputeSummaryStats());
   1555   XLA_VLOG_LINES(1, assignment->GetStats().ToString());
   1556   return std::move(assignment);
   1557 }
   1558 
   1559 }  // namespace xla
   1560