      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Defines the data returned by the XLA buffer assignment packages.
     17 
     18 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
     19 
     20 #include <algorithm>
     21 #include <deque>
     22 #include <ostream>
     23 #include <utility>
     24 
     25 #include "absl/container/flat_hash_map.h"
     26 #include "absl/container/flat_hash_set.h"
     27 #include "absl/memory/memory.h"
     28 #include "absl/strings/str_cat.h"
     29 #include "absl/strings/str_format.h"
     30 #include "tensorflow/compiler/xla/map_util.h"
     31 #include "tensorflow/compiler/xla/service/buffer_value_containers.h"
     32 #include "tensorflow/compiler/xla/service/heap_simulator.h"
     33 #include "tensorflow/compiler/xla/service/hlo.pb.h"
     34 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     35 #include "tensorflow/compiler/xla/shape_util.h"
     36 #include "tensorflow/compiler/xla/status_macros.h"
     37 #include "tensorflow/compiler/xla/types.h"
     38 #include "tensorflow/compiler/xla/util.h"
     39 #include "tensorflow/core/lib/core/errors.h"
     40 #include "tensorflow/core/lib/hash/hash.h"
     41 #include "tensorflow/core/lib/strings/numbers.h"
     42 
     43 namespace xla {
     44 namespace {
     45 
     46 using absl::flat_hash_map;
     47 using absl::flat_hash_set;
     48 using absl::StrAppend;
     49 using absl::StrAppendFormat;
     50 using ::tensorflow::strings::HumanReadableNumBytes;
     51 
     52 template <typename T>
     53 string ColocatedBufferSetsToString(const T& container, const char* title) {
     54   string result;
     55   StrAppend(&result, title, "\n");
     56   for (const auto& it : container) {
     57     StrAppend(&result, "\t", it->ToString(), "\n");
     58   }
     59   return result;
     60 }
     61 
      62 // Checks that the points-to set of 'instruction' is unambiguous and distinct
     63 // (ensured by CopyInsertion), then adds the buffer from the points-to set at
     64 // 'index' to 'colocated_set'.
     65 const LogicalBuffer* AddBufferToColocatedSet(
     66     const HloInstruction* instruction, const ShapeIndex& index,
     67     const TuplePointsToAnalysis& points_to_analysis,
     68     std::vector<const LogicalBuffer*>* colocated_set) {
     69   // CopyInsertion ensures root points-to set is unambiguous and distinct.
     70   const auto& points_to = points_to_analysis.GetPointsToSet(instruction);
     71   DCHECK(!points_to.IsAmbiguous());
     72   colocated_set->push_back(points_to.element(index)[0]);
     73   return colocated_set->back();
     74 }
     75 
     76 // Given the interference map of a graph (the list of interfering node indices
      77 // for each node), performs greedy graph coloring such that interfering nodes
      78 // are assigned different colors. Returns the assigned colors of the nodes,
      79 // where the colors are represented as integer values in [0, color_count).
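         //
         // Illustrative example (not from the original source): with
         //   interference_map = {{1, 2}, {0}, {0}}
         // node 0 interferes with nodes 1 and 2, so it is colored first (highest
         // degree) and receives color 0; nodes 1 and 2 each interfere only with
         // node 0, so both can take color 1, yielding assigned_colors = {0, 1, 1}.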
     80 std::vector<int64> ColorInterferenceGraph(
     81     const std::vector<std::vector<int64>>& interference_map) {
     82   const int64 node_count = interference_map.size();
     83 
     84   // Sort the nodes such that we assign nodes with more interference first. This
     85   // relies on the common heuristic of assigning the most constrained node
     86   // first, but it would be good to investigate other ordering heuristics too.
     87   std::vector<int64> nodes(node_count);
     88   std::iota(nodes.begin(), nodes.end(), 0);
     89   absl::c_sort(nodes, [&interference_map](const int64 i, const int64 j) {
     90     return interference_map[i].size() > interference_map[j].size();
     91   });
     92 
     93   const int64 kColorUnassigned = -1;
     94   std::vector<int64> assigned_colors(node_count, kColorUnassigned);
     95   for (int64 node : nodes) {
     96     // Mark the colors that are already assigned to the neighbors.
     97     std::vector<bool> available_colors(node_count, true);
     98     for (int64 neighbor : interference_map[node]) {
     99       int64 color = assigned_colors[neighbor];
    100       if (color != kColorUnassigned) {
    101         available_colors[color] = false;
    102       }
    103     }
    104 
    105     // Find the color that is not yet assigned to the neighbors.
    106     int64 color = kColorUnassigned;
    107     for (color = 0; color < available_colors.size(); ++color) {
    108       if (available_colors[color]) {
    109         break;
    110       }
    111     }
    112     CHECK_NE(color, kColorUnassigned);
    113     assigned_colors[node] = color;
    114   }
    115   return assigned_colors;
    116 }
    117 
    118 }  // namespace
    119 
    120 Status GatherComputationsByAllocationType(
    121     const HloModule* module,
    122     std::vector<const HloComputation*>* thread_local_computations,
    123     std::vector<const HloComputation*>* global_computations) {
    124   // Create a worklist of computations paired with whether the allocation must
    125   // be thread-local.
    126   std::deque<std::pair<const HloComputation*, bool>> worklist;
    127   worklist.push_back(std::make_pair(module->entry_computation(),
    128                                     /*is_thread_local*/ false));
    129 
    130   // Sets for quickly checking membership. Computations are returned in vectors
    131   // for stable iteration.
    132   flat_hash_set<const HloComputation*> thread_local_set;
    133   flat_hash_set<const HloComputation*> global_set;
    134 
    135   while (!worklist.empty()) {
    136     auto worklist_front = worklist.front();
    137     worklist.pop_front();
    138     const HloComputation* computation = worklist_front.first;
    139     bool is_thread_local = worklist_front.second;
    140     bool in_thread_local_set = thread_local_set.contains(computation);
    141     bool in_global_set = global_set.contains(computation);
    142 
    143     // If the computation has already been added to the respective set, then
    144     // nothing to do.
    145     if ((is_thread_local && in_thread_local_set) ||
    146         (!is_thread_local && in_global_set)) {
    147       continue;
    148     }
    149 
     150     // If the computation has already been added to the other set, this is an
     151     // error condition: the global call to the computation (e.g., while/call)
     152     // may return a reference to one of the thread-local buffers to the calling
     153     // computation, which would become a dangling reference when the
     154     // thread-local buffer is deallocated upon the call's return.
    155     if ((is_thread_local && in_global_set) ||
    156         (!is_thread_local && in_thread_local_set)) {
    157       return InvalidArgument(
    158           "computation %s has conflicting allocation requirements (global "
    159           "and thread-local)",
    160           computation->name());
    161     }
    162 
    163     if (is_thread_local) {
    164       thread_local_set.insert(computation);
    165     } else {
    166       global_set.insert(computation);
    167     }
    168 
    169     for (auto* instruction : computation->instructions()) {
    170       for (HloComputation* subcomputation :
    171            instruction->called_computations()) {
    172         switch (instruction->opcode()) {
    173           case HloOpcode::kCall:
    174           case HloOpcode::kConditional:
    175           case HloOpcode::kWhile:
     176             // Call, conditional and while must be called from a computation
     177             // with global allocations, as they may return references to buffers
     178             // inside the called computation which cannot be thread-local.
    179             if (is_thread_local) {
    180               return InvalidArgument(
    181                   "computation %s cannot contain call/while op because it "
    182                   "requires thread-local buffer allocations",
    183                   computation->name());
    184             }
    185             worklist.push_back(std::make_pair(subcomputation,
    186                                               false));  // Not thread local.
    187             break;
    188           case HloOpcode::kAllReduce:
    189           case HloOpcode::kMap:
    190           case HloOpcode::kReduce:
    191           case HloOpcode::kReduceWindow:
    192           case HloOpcode::kScatter:
    193           case HloOpcode::kSelectAndScatter:
    194           case HloOpcode::kSort:
    195           case HloOpcode::kFusion:
    196             // Map/reduce etc computations are always thread-local.
    197             worklist.push_back(std::make_pair(subcomputation,
    198                                               true));  // Thread local.
    199             break;
    200           default:
    201             return InternalError("Unexpected calling opcode: %s",
    202                                  HloOpcodeString(instruction->opcode()));
    203         }
    204       }
    205     }
    206   }
    207 
    208   // Add the computations to the vectors in post order.
    209   for (auto* computation : module->MakeComputationPostOrder()) {
    210     if (thread_local_set.contains(computation)) {
    211       thread_local_computations->push_back(computation);
    212     } else if (global_set.contains(computation)) {
    213       global_computations->push_back(computation);
    214     }
    215     // If the computation is not reachable from the entry computation, then it
    216     // will not appear in either thread_local_set or global_set. We don't bother
    217     // assigning buffers for these.
    218   }
    219   return Status::OK();
    220 }
    221 
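         // With illustrative values, the string produced below looks like
         // "{index:3, offset:128, size:64}".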
    222 string BufferAllocation::Slice::ToString() const {
    223   return absl::StrCat("{index:", index(), ", offset:", offset_,
    224                       ", size:", size_, "}");
    225 }
    226 
    227 BufferAllocation::Slice BufferAllocation::GetSlice(
    228     const LogicalBuffer& buffer) const {
    229   const OffsetSize os = FindOrDie(assigned_buffers_, &buffer);
    230   return Slice(this, os.offset, os.size);
    231 }
    232 
    233 void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset,
    234                                      int64 size) {
    235   VLOG(4) << "Trying to add " << buffer << " to allocation #" << index();
    236   CHECK(!assigned_buffers_.contains(&buffer))
    237       << "LogicalBuffer " << buffer << " already assigned to allocation "
    238       << index_;
    239   CHECK_LE(offset, size_) << "LogicalBuffer " << buffer
    240                           << " offset out of range";
    241   CHECK_LE(offset + size, size_)
    242       << "LogicalBuffer " << buffer
    243       << " size out of range at offset: " << offset << " with size: " << size;
    244   CHECK_EQ(buffer.color(), color())
    245       << "Buffer color " << buffer.color() << " for buffer " << buffer
    246       << " does not match allocation color " << color() << ".";
    247   OffsetSize offset_size;
    248   offset_size.offset = offset;
    249   offset_size.size = size;
    250   assigned_buffers_.emplace(&buffer, offset_size);
    251 }
    252 
    253 BufferAllocationProto BufferAllocation::ToProto() const {
    254   BufferAllocationProto proto;
    255   proto.set_index(index_);
    256   proto.set_size(size_);
    257   proto.set_is_thread_local(is_thread_local_);
    258   proto.set_is_tuple(is_tuple_);
    259   proto.set_color(color_.value());
    260   if (is_entry_computation_parameter_) {
    261     proto.set_is_entry_computation_parameter(true);
    262     for (int64 idx : param_shape_index()) {
    263       proto.add_parameter_shape_index(idx);
    264     }
    265     proto.set_parameter_number(parameter_number_);
    266   }
    267   proto.set_is_constant(is_constant_);
    268   proto.set_maybe_live_out(maybe_live_out_);
    269   for (const auto& buffer_offset_size : assigned_buffers_) {
    270     BufferAllocationProto::Assigned* proto_assigned = proto.add_assigned();
    271     proto_assigned->set_logical_buffer_id(buffer_offset_size.first->id());
    272     proto_assigned->set_offset(buffer_offset_size.second.offset);
    273     proto_assigned->set_size(buffer_offset_size.second.size);
    274   }
    275   absl::c_sort(*proto.mutable_assigned(),
    276                [](const BufferAllocationProto::Assigned& assign1,
    277                   const BufferAllocationProto::Assigned& assign2) {
    278                  return assign1.logical_buffer_id() <
    279                         assign2.logical_buffer_id();
    280                });
    281   return proto;
    282 }
    283 
    284 string BufferAllocation::ToString() const {
    285   string output;
    286   StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size());
    287   if (color().value() != 0) {
    288     StrAppend(&output, ", color ", color().value());
    289   }
    290   if (is_entry_computation_parameter()) {
    291     StrAppend(&output, ", parameter ", parameter_number(), " at ShapeIndex ",
    292               param_shape_index().ToString());
    293   }
    294   if (is_constant()) {
    295     StrAppend(&output, ", constant");
    296   }
    297   if (is_thread_local()) {
    298     StrAppend(&output, ", thread-local");
    299   }
    300   if (maybe_live_out()) {
    301     StrAppend(&output, ", maybe-live-out");
    302   }
    303   if (IsPreallocatedTempBuffer()) {
    304     StrAppend(&output, ", preallocated-temp");
    305   }
    306   StrAppend(&output, ":\n");
    307   // Dump the assigned buffers ordered by id.
    308   std::vector<const LogicalBuffer*> sorted_buffers;
    309   for (const auto& buffer_offset_size : assigned_buffers_) {
    310     sorted_buffers.push_back(buffer_offset_size.first);
    311   }
    312   absl::c_sort(sorted_buffers,
    313                [](const LogicalBuffer* a, const LogicalBuffer* b) {
    314                  return a->id() < b->id();
    315                });
    316   for (const LogicalBuffer* buffer : sorted_buffers) {
    317     const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer);
    318     StrAppend(&output, absl::StrFormat(
    319                            "  %s [%d,%d]: %s\n", buffer->ToString(),
    320                            offset_size.offset, offset_size.size,
    321                            ShapeUtil::HumanStringWithLayout(buffer->shape())));
    322   }
    323   return output;
    324 }
    325 
    326 std::ostream& operator<<(std::ostream& out, const BufferAllocation& buffer) {
    327   out << buffer.ToString();
    328   return out;
    329 }
    330 
    331 std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s) {
    332   out << s.ToString();
    333   return out;
    334 }
    335 
    336 const PointsToSet& BufferAssignment::GetPointsToSet(
    337     const HloInstruction* instruction) const {
    338   return points_to_analysis().GetPointsToSet(instruction);
    339 }
    340 
    341 bool BufferAssignment::HasAllocation(const LogicalBuffer& buffer) const {
    342   TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer));
    343   return allocation_index_for_buffer_.contains(&buffer);
    344 }
    345 
    346 const BufferAllocation& BufferAssignment::GetAssignedAllocation(
    347     const LogicalBuffer& buffer) const {
    348   CHECK(HasAllocation(buffer));
    349   return GetAllocation(allocation_index_for_buffer_.at(&buffer));
    350 }
    351 
    352 BufferAllocation* BufferAssignment::GetMutableAssignedAllocation(
    353     const LogicalBuffer& buffer) {
    354   return const_cast<BufferAllocation*>(&GetAssignedAllocation(buffer));
    355 }
    356 
    357 std::set<BufferAllocation::Slice> BufferAssignment::GetAllSlices(
    358     const HloInstruction* instruction, const ShapeIndex& index) const {
    359   std::set<BufferAllocation::Slice> result;
    360   for (const LogicalBuffer* buffer : GetSourceBuffers(instruction, index)) {
    361     if (HasAllocation(*buffer)) {
    362       result.insert(GetAssignedAllocation(*buffer).GetSlice(*buffer));
    363     }
    364   }
    365   return result;
    366 }
    367 
    368 const BufferAllocation& BufferAssignment::GetAllocation(
    369     BufferAllocation::Index index) const {
    370   CHECK_GE(index, 0);
    371   CHECK_LT(index, allocations_.size());
    372   return allocations_[index];
    373 }
    374 
    375 const BufferAllocation* BufferAssignment::GetInstructionAllocation(
    376     const HloInstruction* hlo, const ShapeIndex& shape_index) const {
    377   const PointsToSet& points_to_set = points_to_analysis().GetPointsToSet(hlo);
    378   const LogicalBuffer* buffer = points_to_set.element(shape_index)[0];
    379 
    380   if (!HasAllocation(*buffer)) {
    381     return nullptr;
    382   }
    383 
    384   const BufferAllocation& instruction_allocation =
    385       GetAssignedAllocation(*buffer);
    386   return &instruction_allocation;
    387 }
    388 
    389 BufferAllocation* BufferAssignment::GetMutableAllocation(
    390     BufferAllocation::Index index) {
    391   return const_cast<BufferAllocation*>(&GetAllocation(index));
    392 }
    393 
    394 bool BufferAssignment::HasAllocationAt(const HloInstruction* instruction,
    395                                        const ShapeIndex& index) const {
    396   for (const LogicalBuffer* buffer :
    397        GetPointsToSet(instruction).element(index)) {
    398     if (allocation_index_for_buffer_.contains(buffer)) {
    399       return true;
    400     }
    401   }
    402   return false;
    403 }
    404 
    405 bool BufferAssignment::HasTopLevelAllocation(
    406     const HloInstruction* instruction) const {
    407   return HasAllocationAt(instruction, /*index=*/{});
    408 }
    409 
    410 StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueSlice(
    411     const HloInstruction* instruction, const ShapeIndex& index) const {
    412   VLOG(3) << "Trying to find unique slice for " << instruction->name() << " ["
    413           << index << "]";
    414   BufferAllocation::Slice result;
    415   for (const LogicalBuffer* buffer :
    416        GetPointsToSet(instruction).element(index)) {
    417     VLOG(3) << "Examining buffer " << *buffer;
    418     if (HasAllocation(*buffer)) {
    419       VLOG(3) << "Has allocation";
    420       const BufferAllocation::Slice slice =
    421           GetAssignedAllocation(*buffer).GetSlice(*buffer);
    422       if (result.allocation() == nullptr) {
    423         result = slice;
    424       } else if (result != slice) {
    425         return FailedPrecondition(
    426             "BufferAllocation::Slice for instruction %s at index %s cannot "
    427             "be determined at compile-time.",
    428             instruction->name(), index.ToString());
    429       }
    430     } else {
    431       VLOG(3) << "No allocation";
    432     }
    433   }
    434   if (result.allocation() == nullptr) {
    435     return FailedPrecondition(
    436         "BufferAllocation::Slice not assigned for instruction %s at index %s",
    437         instruction->name(), index.ToString());
    438   }
    439   return result;
    440 }
    441 
    442 StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueTopLevelSlice(
    443     const HloInstruction* instruction) const {
    444   return GetUniqueSlice(instruction, /*index=*/{});
    445 }
    446 
    447 bool BufferAssignment::SharesSliceAtIndex(
    448     const HloInstruction* hlo_a, const ShapeIndex& shape_index_a,
    449     const HloInstruction* hlo_b, const ShapeIndex& shape_index_b) const {
    450   return GetUniqueSlice(hlo_a, shape_index_a).ConsumeValueOrDie() ==
    451          GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie();
    452 }
    453 
    454 bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
    455                                           const HloInstruction* hlo_b) const {
    456   using SliceSet = flat_hash_set<BufferAllocation::Slice>;
     457   // Gets the slices of all of instr's subshapes.  If any subshape doesn't have
     458   // an assigned slice, returns the empty set.
    459   auto collect_slices = [&](const HloInstruction* instr) -> SliceSet {
    460     SliceSet slices;
    461     Status status = ShapeUtil::ForEachSubshapeWithStatus(
    462         instr->shape(),
    463         [&](const Shape& /*subshape*/, const ShapeIndex& index) {
    464           auto shape_slices = GetAllSlices(instr, index);
    465           if (shape_slices.empty()) {
    466             return InvalidArgument("No slices assigned to part of instr.");
    467           }
    468           slices.insert(shape_slices.begin(), shape_slices.end());
    469           return Status::OK();
    470         });
    471     if (!status.ok()) {
    472       return {};
    473     }
    474     return slices;
    475   };
    476 
    477   SliceSet slices_a = collect_slices(hlo_a);
    478   SliceSet slices_b = collect_slices(hlo_b);
    479   // hlo_a and hlo_b have disjoint slices if collect_slices succeeded (i.e.
    480   // didn't return the empty set) for both HLOs, and the two resulting sets of
    481   // slices are disjoint.
    482   return !slices_a.empty() && !slices_b.empty() &&
    483          absl::c_none_of(slices_a, [&](const BufferAllocation::Slice& slice) {
    484            return slices_b.contains(slice);
    485          });
    486 }
    487 
    488 StatusOr<BufferAllocation::Slice>
    489 BufferAssignment::GetUniqueTopLevelOutputSlice() const {
    490   return GetUniqueTopLevelSlice(
    491       module_->entry_computation()->root_instruction());
    492 }
    493 
    494 BufferAllocation* BufferAssignment::NewEmptyAllocation(
    495     int64 size, LogicalBuffer::Color color) {
    496   BufferAllocation::Index index = allocations_.size();
    497   allocations_.emplace_back(index, size, color);
    498   BufferAllocation* allocation = &allocations_.back();
    499   return allocation;
    500 }
    501 
    502 BufferAllocation* BufferAssignment::NewAllocation(const LogicalBuffer& buffer,
    503                                                   int64 size) {
    504   BufferAllocation* allocation = NewEmptyAllocation(size, buffer.color());
    505   AddAssignment(allocation, buffer, /*offset=*/0, size);
    506   allocation->peak_buffers_.push_back(&buffer);
    507   return allocation;
    508 }
    509 
     510 // Assigns the given buffer to the given allocation at the given offset and size.
    511 void BufferAssignment::AddAssignment(BufferAllocation* allocation,
    512                                      const LogicalBuffer& buffer, int64 offset,
    513                                      int64 size) {
    514   CHECK(!allocation_index_for_buffer_.contains(&buffer))
    515       << "LogicalBuffer " << buffer << " already has an allocation.";
    516   CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty())
    517       << "Non-reusable allocation already assigned a buffer: "
    518       << allocation->ToString();
    519 
    520   TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer));
    521 
    522   allocation->AddAssignment(buffer, offset, size);
    523   if (liveness().MaybeLiveOut(buffer)) {
    524     allocation->set_maybe_live_out(true);
    525   }
    526   allocation_index_for_buffer_[&buffer] = allocation->index();
    527 }
    528 
    529 // Combines allocations of temporary buffers of the same color into one big
    530 // BufferAllocation.
    531 void BufferAssignment::CombineTempAllocations() {
    532   VLOG(1) << "CombineTempAllocations()";
    533   flat_hash_map<LogicalBuffer::Color, BufferAllocation,
    534                 LogicalBuffer::Color::Hasher>
    535       combined_allocation_map;
    536 
    537   // Move all temp allocations into a single run at the end of the allocations
    538   // vector.
    539   const auto first_temp_it =
    540       std::partition(allocations_.begin(), allocations_.end(),
    541                      [](const BufferAllocation& allocation) {
    542                        return !allocation.IsPreallocatedTempBuffer();
    543                      });
    544 
    545   // Walk over the run of temp allocations, collecting the allocations belonging
    546   // to the same color.
    547   if (first_temp_it != allocations_.end()) {
    548     for (auto it = first_temp_it; it != allocations_.end(); ++it) {
    549       const BufferAllocation& temp_allocation = *it;
    550       LogicalBuffer::Color color = temp_allocation.color();
    551       auto combined_it = combined_allocation_map.find(color);
    552       if (combined_it == combined_allocation_map.end()) {
    553         // We have found the first temp allocation of this color. Collect
    554         // the other temp allocations of the same color into it.
    555         VLOG(1) << "Combined temp allocation for color " << color
    556                 << " is: " << temp_allocation;
    557         combined_allocation_map.emplace(color, temp_allocation);
    558         continue;
    559       }
    560 
    561       auto* combined_allocation = &combined_it->second;
    562       VLOG(1) << "Combined allocation absorbing temp allocation: "
    563               << temp_allocation;
    564 
    565       // Each temp allocation is placed end-to-end, accounting for alignment.
    566       // The offset of each buffer in the combined allocation is computed from
    567       // the base offset of the allocation.
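               // For example (illustrative numbers): if the combined allocation currently
               // has size 96 and the alignment for this color is 64, then base = 128, and
               // a buffer stored at offset 16 within the absorbed temp allocation ends up
               // at offset 128 + 16 = 144 in the combined allocation.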
    568       int64 alignment = color_alignment_(color);
    569       const int64 base =
    570           RoundUpToNearest(combined_allocation->size(), alignment);
    571       combined_allocation->set_size(base + temp_allocation.size());
    572       for (const auto& buffer_offset_size : temp_allocation.assigned_buffers_) {
    573         const LogicalBuffer* buffer = buffer_offset_size.first;
    574         const int64 offset = buffer_offset_size.second.offset;
    575         const int64 size = buffer_offset_size.second.size;
    576         combined_allocation->AddAssignment(*buffer, base + offset, size);
    577       }
    578       if (!temp_allocation.HeapTraces().empty()) {
    579         CHECK_EQ(temp_allocation.HeapTraces().size(), 1);
    580         combined_allocation->AddHeapTrace(temp_allocation.HeapTraces().front());
    581       }
    582       combined_allocation->peak_buffers_.insert(
    583           combined_allocation->peak_buffers_.end(),
    584           temp_allocation.peak_buffers_.begin(),
    585           temp_allocation.peak_buffers_.end());
    586     }
    587     // Replace all existing temporary allocations with the new combined
    588     // allocations.
    589     allocations_.erase(first_temp_it, allocations_.end());
    590     for (auto& combined : combined_allocation_map) {
    591       allocations_.push_back(combined.second);
    592       temp_allocation_total_size_ += combined.second.size();
    593     }
    594   }
    595 
    596   // Update allocation indices to their new positions.
    597   allocation_index_for_buffer_.erase(allocation_index_for_buffer_.begin(),
    598                                      allocation_index_for_buffer_.end());
    599   for (size_t index = 0; index < allocations_.size(); ++index) {
    600     BufferAllocation* allocation = &allocations_[index];
    601     allocation->set_index(index);
    602     for (const auto& buffer_offset_size : allocation->assigned_buffers_) {
    603       const LogicalBuffer* buffer = buffer_offset_size.first;
    604       allocation_index_for_buffer_[buffer] = index;
    605     }
    606   }
    607 }
    608 
    609 Status BufferAssignment::ComputeSummaryStats() {
    610   for (auto& allocation : Allocations()) {
    611     if (allocation.is_entry_computation_parameter()) {
    612       stats_.parameter_allocation_count++;
    613       stats_.parameter_allocation_bytes += allocation.size();
    614     }
    615     if (allocation.is_constant()) {
    616       stats_.constant_allocation_count++;
    617       stats_.constant_allocation_bytes += allocation.size();
    618     }
    619     if (allocation.maybe_live_out()) {
    620       stats_.maybe_live_out_allocation_count++;
    621       stats_.maybe_live_out_allocation_bytes += allocation.size();
    622     }
    623     if (allocation.IsPreallocatedTempBuffer()) {
    624       stats_.preallocated_temp_allocation_count++;
    625       stats_.preallocated_temp_allocation_bytes += allocation.size();
    626     }
    627     stats_.total_allocation_count++;
    628     stats_.total_allocation_bytes += allocation.size();
    629   }
    630 
    631   // Only compute total fragmentation if all computations have schedules.
    632   HloSchedule schedule(module_);
    633   bool schedule_complete = true;
    634   for (const auto& computation : module_->computations()) {
    635     if (!computation->IsFusionComputation()) {
    636       const HloInstructionSequence* sequence =
    637           liveness_->hlo_ordering().SequentialOrder(*computation);
    638       if (sequence == nullptr) {
    639         schedule_complete = false;
    640       } else {
    641         schedule.set_sequence(computation, *sequence);
    642       }
    643     }
    644   }
    645   if (schedule_complete) {
    646     TF_RETURN_IF_ERROR(schedule.Verify());
    647     TF_ASSIGN_OR_RETURN(
    648         const int64 min_size,
    649         HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_));
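             // Fragmentation here is measured against the schedule-dependent lower
             // bound: the bytes actually allocated minus the minimum the heap simulator
             // reports the module could use under this schedule.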
    650     stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size;
    651   }
    652 
    653   return Status::OK();
    654 }
    655 
    656 string BufferAssignment::Stats::ToString() const {
    657   string s;
    658   StrAppendFormat(&s, "BufferAssignment stats:\n");
    659   StrAppendFormat(&s, "             parameter allocation: %10s\n",
    660                   HumanReadableNumBytes(parameter_allocation_bytes));
    661   StrAppendFormat(&s, "              constant allocation: %10s\n",
    662                   HumanReadableNumBytes(constant_allocation_bytes));
    663   StrAppendFormat(&s, "        maybe_live_out allocation: %10s\n",
    664                   HumanReadableNumBytes(maybe_live_out_allocation_bytes));
    665   StrAppendFormat(&s, "     preallocated temp allocation: %10s\n",
    666                   HumanReadableNumBytes(preallocated_temp_allocation_bytes));
    667   if (preallocated_temp_fragmentation_bytes >= 0) {
    668     const double percent = 100. * preallocated_temp_fragmentation_bytes /
    669                            preallocated_temp_allocation_bytes;
    670     StrAppendFormat(
    671         &s, "  preallocated temp fragmentation: %10s (%.2f%%)\n",
    672         HumanReadableNumBytes(preallocated_temp_fragmentation_bytes), percent);
    673   }
    674   StrAppendFormat(&s, "                 total allocation: %10s\n",
    675                   HumanReadableNumBytes(total_allocation_bytes));
    676   if (total_fragmentation_bytes >= 0) {
    677     const double percent =
    678         100. * total_fragmentation_bytes / total_allocation_bytes;
    679     StrAppendFormat(&s, "              total fragmentation: %10s (%.2f%%)\n",
    680                     HumanReadableNumBytes(total_fragmentation_bytes), percent);
    681   }
    682   return s;
    683 }
    684 
    685 string BufferAssignment::ToString() const {
    686   string output;
    687   absl::StrAppend(&output, "BufferAssignment:\n");
    688   for (auto& allocation : allocations_) {
    689     absl::StrAppend(&output, allocation.ToString());
    690   }
    691   return output;
    692 }
    693 
    694 BufferAssignmentProto BufferAssignment::ToProto() const {
    695   BufferAssignmentProto proto;
     696   // NOTE: TuplePointsToAnalysis state is serialized here in BufferAssignment,
    697   // because we need to do the HasAllocation check for each buffer. Otherwise
    698   // the buffer_size_ call might fail for some backends.
    699   const TuplePointsToAnalysis& points_to_analysis =
    700       liveness_->points_to_analysis();
    701   for (LogicalBuffer::Id id = 0; id < points_to_analysis.num_logical_buffers();
    702        id++) {
    703     auto& buffer = points_to_analysis.logical_buffer(id);
    704     if (HasAllocation(buffer)) {
    705       LogicalBufferProto proto_buffer = buffer.ToProto(buffer_size_);
    706       proto.add_logical_buffers()->Swap(&proto_buffer);
    707 
    708       // Fill buffer aliases.
    709       for (const BufferAlias& alias :
    710            points_to_analysis.GetBufferAliases(buffer)) {
    711         if (alias.instruction() == buffer.instruction() &&
    712             alias.index() == buffer.index()) {
    713           continue;  // skip self-aliases
    714         }
    715         BufferAssignmentProto::BufferAlias* proto_alias =
    716             proto.add_buffer_aliases();
    717         LogicalBufferProto::Location proto_alias_location =
    718             BufferValue::ToLocationProto(*alias.instruction(), alias.index());
    719         proto_alias->set_source_buffer_id(buffer.id());
    720         proto_alias->mutable_location()->Swap(&proto_alias_location);
    721       }
    722     }
    723   }
    724   for (const BufferAllocation& allocation : Allocations()) {
    725     BufferAllocationProto proto_allocation = allocation.ToProto();
    726     proto.add_buffer_allocations()->Swap(&proto_allocation);
    727     for (const HeapSimulatorTrace& heap_trace : allocation.HeapTraces()) {
    728       *proto.add_heap_simulator_traces() = heap_trace;
    729     }
    730   }
    731   return proto;
    732 }
    733 
    734 /* static */
    735 StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::Run(
    736     const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
    737     LogicalBuffer::SizeFunction buffer_size,
    738     LogicalBuffer::AlignmentFunction color_alignment,
    739     bool allow_input_output_aliasing, bool allocate_buffers_for_constants,
    740     BufferLiveness::Colorer colorer, ReuseAllocationFunction reuse_checker) {
    741   BufferAssigner assigner(allocate_buffers_for_constants, std::move(colorer),
    742                           std::move(reuse_checker));
    743   return assigner.CreateAssignment(module, std::move(hlo_ordering),
    744                                    std::move(buffer_size),
    745                                    std::move(color_alignment));
    746 }
    747 
    748 namespace {
    749 
     750 // 'a' and 'b' are in different subcomputations. Checks for the case where 'a'
     751 // is inside the while body and 'b' is outside, as part of the same while's
     752 // init-operand or while-result.
    753 bool MayInterfereAcrossSubcomputations(BufferAssignment* assignment,
    754                                        const LogicalBuffer& a_buffer,
    755                                        const LogicalBuffer& b_buffer) {
    756   const CallGraph& call_graph =
    757       assignment->liveness().hlo_ordering().call_graph();
    758   const HloInstruction* a_ancestor;
    759   const HloInstruction* b_ancestor;
    760   std::tie(a_ancestor, b_ancestor) =
    761       call_graph.NearestAncestorsInSameComputation(a_buffer.instruction(),
    762                                                    b_buffer.instruction());
    763   if (a_ancestor == nullptr) {
    764     // No common ancestor.
    765     return true;
    766   }
    767   if (a_ancestor->opcode() == HloOpcode::kWhile &&
    768       call_graph.InstructionIsNestedIn(a_buffer.instruction(),
    769                                        a_ancestor->while_body())) {
    770     const PointsToSet& init_set =
    771         assignment->liveness().points_to_analysis().GetPointsToSet(
    772             a_ancestor->operand(0));
    773     if (init_set.ContainsBuffer(b_buffer)) {
    774       VLOG(4) << "Can't interfere: " << a_buffer << " and " << b_buffer
    775               << " (part of while-operand)";
    776       return false;
    777     }
    778     const PointsToSet& while_set =
    779         assignment->liveness().points_to_analysis().GetPointsToSet(a_ancestor);
    780     if (while_set.ContainsBuffer(b_buffer)) {
    781       VLOG(4) << "Can't interfere: " << a_buffer << " and " << b_buffer
    782               << " (part of while)";
    783       return false;
    784     }
    785   }
    786   return true;
    787 }
    788 
     789 // Returns true if a and b can't possibly interfere (and therefore further
    790 // checking for interference can be skipped). This function checks for special
    791 // cases where copy insertion guarantees no interference, but the regular buffer
    792 // liveness is too conservative:
    793 //
    794 // Operations inside a while-body can't interfere with operations outside the
    795 // while op if their last use is at the while-loop itself as part of the
    796 // while-init op, or the while-result.  For ops that are live across a
    797 // while-loop, copy insertion will already insert the necessary copies to avoid
    798 // such interference.
    799 //
    800 // This allows sharing buffers in cases like this:
    801 // init = {...}
    802 // while (init):
    803 //  p = param(0)
    804 //  gte = get-tuple-element(p), index=i
    805 //  t1 = op1 (gte)
    806 //  t2 = op2 (t1)
    807 //  ROOT tuple = {..., t2, ...}
    808 //
    809 // where t1 and t2 can share the same buffer.
    810 bool MaySkipInterferenceCheck(BufferAssignment* assignment,
    811                               const LogicalBuffer& a_buffer,
    812                               const LogicalBuffer& b_buffer) {
    813   if (a_buffer.instruction()->parent() == b_buffer.instruction()->parent()) {
    814     // Ops within the same computation are not handled here. Assume that they
    815     // may interfere.
    816     return false;
    817   }
    818   return !MayInterfereAcrossSubcomputations(assignment, a_buffer, b_buffer) ||
    819          !MayInterfereAcrossSubcomputations(assignment, b_buffer, a_buffer);
    820 }
    821 
    822 }  // namespace
    823 
    824 bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
    825                                        const LogicalBuffer& buffer,
    826                                        BufferAssignment* assignment) {
    827   const LogicalBuffer::SizeFunction& buffer_size = assignment->buffer_size_;
    828 
    829   CHECK(!assignment->HasAllocation(buffer))
    830       << "buffer " << buffer << " already has an allocation assigned.";
    831 
    832   VLOG(4) << "Trying to assign " << buffer << " to allocation: " << *allocation;
    833 
    834   if (buffer.color() != allocation->color()) {
     835     VLOG(4) << "Can't assign: buffer has color " << buffer.color()
    836             << " and allocation has color " << allocation->color() << ".";
    837     return false;
    838   }
    839 
    840   if (buffer_size(buffer) > allocation->size()) {
    841     VLOG(4) << "Can't assign: buffer is larger than allocation ("
    842             << buffer_size(buffer) << " > " << allocation->size() << ")";
    843     return false;
    844   }
    845 
    846   if (allocation->is_readonly()) {
    847     VLOG(4) << "Can't assign: allocation is readonly";
    848     return false;
    849   }
    850 
    851   if (reuse_checker_ != nullptr &&
    852       !reuse_checker_(*assignment, *allocation, buffer)) {
    853     VLOG(4) << "Can't assign: reuse_checker_(allocation, buffer) == false";
    854     return false;
    855   }
    856 
    857   if (!allocation->is_reusable()) {
    858     VLOG(4) << "Can't assign: allocation is not reusable";
    859     return false;
    860   }
    861 
    862   for (const auto& buffer_offset_size : allocation->assigned_buffers()) {
    863     const LogicalBuffer& assigned_buffer = *buffer_offset_size.first;
    864     if (MaySkipInterferenceCheck(assignment, buffer, assigned_buffer)) {
    865       continue;
    866     }
    867     if (assignment->liveness().MayInterfere(assigned_buffer, buffer)) {
    868       VLOG(4) << "Can't assign: assignee " << assigned_buffer
    869               << " may interfere with " << buffer;
    870       return false;
    871     }
     872     // Copy instructions don't share a buffer with their input operand.
    873     if (buffer.instruction()->IsUserOf(assigned_buffer.instruction()) &&
    874         buffer.instruction()->opcode() == HloOpcode::kCopy) {
    875       VLOG(4) << "Can't assign: assignee " << assigned_buffer
    876               << " is used at copy instruction " << buffer;
    877       return false;
    878     }
    879   }
    880 
    881   // If the buffer is live out of the computation then it should only be
     882   // assigned an allocation which exactly fits the result to avoid wasting memory
    883   // (result buffers can have arbitrary lifetimes).
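           // For example (illustrative sizes): a 100-byte live-out buffer is not placed
           // into an existing 256-byte allocation, since the surplus 156 bytes could not
           // be reclaimed for as long as the result stays alive.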
    884   if (assignment->liveness().MaybeLiveOut(buffer) &&
    885       allocation->size() != buffer_size(buffer)) {
    886     VLOG(4) << "Can't assign: buffer " << buffer
     887             << " is live out and size not the same as allocation";
    888     return false;
    889   }
    890 
    891   assignment->AddAssignment(allocation, buffer, /*offset=*/0,
    892                             buffer_size(buffer));
    893   return true;
    894 }
    895 
    896 Status BufferAssigner::AssignBuffersForComputation(
    897     const HloComputation* computation, bool is_thread_local,
    898     const flat_hash_set<const LogicalBuffer*>& colocated_buffers,
    899     const flat_hash_set<BufferAllocation::Index>& colocated_allocations,
    900     flat_hash_map<const HloComputation*, flat_hash_set<const LogicalBuffer*>>*
    901         buffers_to_assign_sequentially,
    902     BufferAssignment* assignment) {
    903   // Buffers are sorted and assigned to BufferAllocations in decreasing order of
    904   // size.
    905   std::vector<const LogicalBuffer*> sorted_buffers;
    906   for (auto* instruction : computation->instructions()) {
     907     // Add all buffers which this instruction defines. Instructions which don't
     908     // define buffers (e.g., a bitcast, which just forwards a pointer) don't need
     909     // any allocations.
    910     for (const LogicalBuffer* buffer :
    911          assignment->points_to_analysis().GetBuffersDefinedByInstruction(
    912              instruction)) {
    913       sorted_buffers.push_back(buffer);
    914     }
    915   }
    916 
     917   // Generate a post order of the instructions, used below as a tiebreaker when
     918   // sorting the LogicalBuffers.
    919   flat_hash_map<const HloInstruction*, int> post_order_position;
    920   int position = 0;
    921   for (auto* instruction : computation->MakeInstructionPostOrder()) {
    922     post_order_position.emplace(instruction, position);
    923     position++;
    924   }
    925 
    926   // If there is a sequential instruction ordering, we'll delay assignment of
    927   // temp buffers until after the main assignment loop.
    928   const BufferLiveness& liveness = assignment->liveness();
    929   const bool has_sequential_order =
    930       liveness.hlo_ordering().SequentialOrder(*computation) != nullptr;
    931   if (has_sequential_order && buffers_to_assign_sequentially != nullptr) {
    932     // Every sequential computation must get an entry in the
    933     // buffers_to_assign_sequentially map, even if we end up with an empty set
    934     // of buffers. This ensures we can correctly determine whether to run
    935     // whole-module heap simulation.
    936     buffers_to_assign_sequentially->emplace(
    937         computation, flat_hash_set<const LogicalBuffer*>());
    938   }
    939 
    940   // Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers
    941   // first for simplicity. This means any previously created BufferAllocation is
     942   // necessarily large enough to hold the output of the current buffer under
     943   // consideration.
    944   //
    945   // As a secondary sorting criteria, if the instructions are sequentially
    946   // ordered, we assign live-out buffers before others. Note that for sequential
    947   // computations, we'll take temp buffers that can't re-use any allocations and
    948   // assign them via a heap scheduler. By assigning live-out buffers first, we
    949   // increase the odds that temp buffers can re-use an allocation.
    950   //
     951   // As a final tiebreaker, use the post order position of the HLO instruction
     952   // which defines the buffer. This means an instruction will appear after its
     953   // operands (assuming the operands are the same or larger size), enabling the
     954   // important reuse case where an elementwise instruction reuses one of its
     955   // operands' buffers. This improves locality.
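           //
           // Illustrative ordering under these rules (with a sequential order): buffers
           // of sizes {256, 256, 128}, where only the second 256-byte buffer may be
           // live out, sort as {256 (live-out), 256, 128}; remaining ties fall back to
           // the defining instruction's post order position.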
    956   absl::c_sort(sorted_buffers,
    957                [has_sequential_order, &liveness, &post_order_position,
    958                 assignment](const LogicalBuffer* a, const LogicalBuffer* b) {
    959                  // Primary sort is by decreasing buffer size.
    960                  const int64 a_size = assignment->buffer_size_(*a);
    961                  const int64 b_size = assignment->buffer_size_(*b);
    962                  if (a_size != b_size) {
    963                    return a_size > b_size;  // use ">" for decreasing size.
    964                  }
    965                  // Otherwise live out buffers come before others, if the
    966                  // instructions are sequentially ordered.
    967                  if (has_sequential_order) {
    968                    const bool a_live_out = liveness.MaybeLiveOut(*a);
    969                    const bool b_live_out = liveness.MaybeLiveOut(*b);
    970                    if (a_live_out != b_live_out) {
    971                      return a_live_out;
    972                    }
    973                  }
    974                  // Final tiebreaker is in instruction post order.
    975                  return post_order_position.at(a->instruction()) <
    976                         post_order_position.at(b->instruction());
    977                });
    978 
    979   // BufferAllocations are necessarily created in decreasing size order. Keep
    980   // indices of previously created BufferAllocations in allocation_indices.
    981   std::vector<BufferAllocation::Index> allocation_indices;
    982   for (const LogicalBuffer* buffer : sorted_buffers) {
    983     VLOG(3) << "Assigning allocation to: " << *buffer;
    984     if (colocated_buffers.contains(buffer)) {
    985       // Colocated buffers are currently assigned in an earlier pass.
    986       VLOG(3) << "Skipping colocated buffer: " << *buffer;
    987       continue;
    988     }
    989 
    990     TF_RET_CHECK(!assignment->HasAllocation(*buffer));
    991 
    992     const HloInstruction* instruction = buffer->instruction();
    993     const int64 buffer_size = assignment->buffer_size_(*buffer);
    994 
    995     if (instruction->opcode() == HloOpcode::kConstant) {
    996       if (allocate_buffers_for_constants_) {
    997         BufferAllocation* allocation =
    998             assignment->NewAllocation(*buffer, buffer_size);
    999         allocation->set_constant(true);
   1000         VLOG(3) << "New allocation #" << allocation->index() << " for constant "
   1001                 << *buffer;
   1002       }
   1003       continue;
   1004     }
   1005 
   1006     const bool is_entry_parameter =
   1007         instruction->opcode() == HloOpcode::kParameter &&
   1008         computation == computation->parent()->entry_computation();
   1009     if (is_entry_parameter) {
   1010       // If the LogicalBuffer is part of an external parameter, creates a new
   1011       // allocation and sets its parameter number. Parameters of non-entry
   1012       // computations do not need special allocations because they live inside
   1013       // callers.
   1014       BufferAllocation* allocation =
   1015           assignment->NewAllocation(*buffer, buffer_size);
   1016       bool parameter_has_alias =
   1017           assignment->module().input_output_alias_config().ParameterHasAlias(
   1018               instruction->parameter_number(), buffer->index());
   1019       allocation->set_entry_computation_parameter(
   1020           instruction->parameter_number(), buffer->index(),
   1021           parameter_has_alias);
   1022       VLOG(3) << "Mark allocation #" << allocation->index()
   1023               << " as entry computation parameter: " << *buffer;
   1024       continue;
   1025     }
   1026 
   1027     if (is_thread_local) {
   1028       BufferAllocation* allocation =
   1029           assignment->NewAllocation(*buffer, buffer_size);
   1030       allocation->set_is_thread_local(true);
   1031       VLOG(3) << "New allocation #" << allocation->index()
   1032               << " for thread-local: " << *buffer;
   1033       continue;
   1034     }
   1035 
   1036     if (buffer->shape().IsTuple()) {
   1037       BufferAllocation* allocation =
   1038           assignment->NewAllocation(*buffer, buffer_size);
   1039       allocation->set_is_tuple(true);
   1040       VLOG(3) << "New allocation #" << allocation->index()
   1041               << " for tuple-shaped buffer: " << *buffer;
   1042       continue;
   1043     }
   1044 
   1045     // First try to assign a LogicalBuffer to one of its operand allocations to
   1046     // improve locality. This is only possible with elementwise operations
   1047     // (checked in liveness analysis) which are necessarily top-level
   1048     // array-shaped buffers.
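             // For example (illustrative), the output of c = add(a, b) can often reuse
             // the allocation already holding 'a' or 'b', provided liveness permits it.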
   1049     if (buffer->IsTopLevel() && !buffer->IsTuple()) {
   1050       for (auto* operand : instruction->operands()) {
   1051         bool assigned_operand = false;
   1052         for (const auto& operand_slice :
   1053              assignment->GetAllSlices(operand, /*index=*/{})) {
   1054           BufferAllocation* allocation =
   1055               assignment->GetMutableAllocation(operand_slice.index());
   1056           if (!colocated_allocations.contains(allocation->index())) {
   1057             // TODO(b/32491382) Colocated buffers are currently assigned in an
   1058             // earlier pass, and so can break the "increasing allocation size"
   1059             // invariant in this function (causing this CHECK to fail). However,
   1060             // the call to MaybeAssignBuffer is safe as it returns false if
   1061             // allocation.size < buffer.size.
   1062             CHECK_GE(allocation->size(), buffer_size);
   1063           }
   1064           if (MaybeAssignBuffer(allocation, *buffer, assignment)) {
   1065             VLOG(3) << "Reusing (operand) allocation #" << allocation->index()
   1066                     << " for: " << *buffer;
   1067             assigned_operand = true;
   1068             break;
   1069           }
   1070         }
   1071         if (assigned_operand) {
   1072           break;
   1073         }
   1074       }
   1075     }
   1076 
   1077     if (!assignment->HasAllocation(*buffer)) {
    1078       // Find the smallest allocation which can be reused, iterating from the end
    1079       // of allocation_indices (smallest) to the beginning (largest).
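               // Since allocations were created while walking buffers in decreasing size
               // order, later entries of allocation_indices are the smaller ones; e.g.
               // with allocation sizes {512, 256, 128}, the 128-byte one is tried first.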
   1080       for (int allocation_index = allocation_indices.size() - 1;
   1081            allocation_index >= 0; allocation_index--) {
   1082         BufferAllocation* allocation = assignment->GetMutableAllocation(
   1083             allocation_indices[allocation_index]);
    1084         // Buffers are iterated in decreasing size order, so any previously
    1085         // created allocation must be large enough to hold this buffer
    1086         // (with the exception of colocated buffers).
   1087         if (!colocated_allocations.contains(allocation->index())) {
   1088           // TODO(b/32491382) Colocated buffers are currently assigned in an
   1089           // earlier pass, and so can break the "increasing allocation size"
   1090           // invariant in this function (causing this CHECK to fail). However,
   1091           // the call to MaybeAssignBuffer is safe as it returns false if
   1092           // allocation.size < buffer.size.
   1093           CHECK_GE(allocation->size(), buffer_size);
   1094         }
   1095 
   1096         if (MaybeAssignBuffer(allocation, *buffer, assignment)) {
   1097           VLOG(3) << "Reusing allocation #" << allocation->index()
   1098                   << " for: " << *buffer;
   1099           break;
   1100         }
   1101       }
   1102     }
   1103 
   1104     if (!assignment->HasAllocation(*buffer) && has_sequential_order &&
   1105         !liveness.MaybeLiveOut(*buffer)) {
   1106       // There is a sequential instruction ordering, so we delay assignment of
   1107       // temp buffers until after the loop. We do this right before we decide to
   1108       // create a new allocation, to ensure we've exhausted all the buffer
   1109       // re-use cases above.
   1110       //
   1111       // Entry parameters and thread local buffers were already handled earlier
   1112       // in this loop iteration.  See BufferAllocation::IsPreallocatedTempBuffer
   1113       // for the definition of temp buffers.
   1114       CHECK(!is_entry_parameter) << *buffer;
   1115       CHECK(!is_thread_local) << *buffer;
   1116       (*buffers_to_assign_sequentially)[computation].insert(buffer);
   1117       VLOG(3) << "Delaying assignment of temp buffer: " << *buffer;
   1118       continue;
   1119     }
   1120 
   1121     if (!assignment->HasAllocation(*buffer)) {
   1122       BufferAllocation* allocation =
   1123           assignment->NewAllocation(*buffer, buffer_size);
   1124       allocation_indices.push_back(allocation->index());
   1125       VLOG(3) << "New allocation #" << allocation->index()
   1126               << " for: " << *buffer;
   1127     }
   1128   }
   1129 
   1130   return Status::OK();
   1131 }
   1132 
   1133 flat_hash_map<LogicalBuffer::Color, flat_hash_set<const LogicalBuffer*>,
   1134               LogicalBuffer::Color::Hasher>
   1135 BufferAssigner::SplitBuffersByColor(
   1136     const flat_hash_set<const LogicalBuffer*>& buffers) {
   1137   flat_hash_map<LogicalBuffer::Color, flat_hash_set<const LogicalBuffer*>,
   1138                 LogicalBuffer::Color::Hasher>
   1139       color_map;
   1140   for (auto buffer : buffers) {
   1141     color_map[buffer->color()].insert(buffer);
   1142   }
   1143   return color_map;
   1144 }
   1145 
   1146 Status BufferAssigner::AssignBuffersWithSequentialOrdering(
   1147     const flat_hash_map<const HloComputation*,
   1148                         flat_hash_set<const LogicalBuffer*>>&
   1149         buffers_to_assign_sequentially,
   1150     bool run_whole_module_heap_simulation, BufferAssignment* assignment) {
   1151   // Run the sequence of instructions through the heap simulator.  The heuristic
   1152   // that seems to give the best results is lazy-best-fit, with all runs of
   1153   // alloc / free calls sorted in decreasing size order.
   1154   const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering();
   1155 
   1156   // Returns a heap algorithm that chooses the best result from several
   1157   // algorithms.
   1158   auto get_heap_algorithm = [&](int64 alignment) {
   1159     auto algorithms =
   1160         absl::make_unique<std::vector<std::unique_ptr<HeapAlgorithm>>>();
   1161     algorithms->push_back(absl::make_unique<DecreasingSizeRunsHeap>(
   1162         absl::make_unique<LazyBestFitHeap>(alignment)));
   1163     algorithms->push_back(
   1164         absl::make_unique<GlobalDecreasingSizeBestFitHeap>(alignment));
   1165     return absl::make_unique<ChooseBestHeapAlgorithm>(std::move(algorithms));
   1166   };
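           // Note: ChooseBestHeapAlgorithm is expected to run each candidate over
           // the same trace and keep whichever result produced the smallest
           // simulated heap, so adding more candidates here can only tighten (never
           // loosen) the final temp-buffer size.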
   1167 
   1168   if (run_whole_module_heap_simulation) {
   1169     // Run the heap simulation over the whole module. This reduces memory usage,
   1170     // since buffers for kCall, kWhile, and kConditional sub-computations are
   1171     // only live for the duration of their calling instructions.
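             // For example, if the entry computation invokes computations A and B via
             // two sequential kCall instructions, whole-module simulation lets A's and
             // B's temporary buffers reuse the same heap offsets, whereas the
             // per-computation path below would conservatively reserve disjoint space
             // for both.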
   1172     VLOG(1) << "Running whole-module heap simulation";
   1173     HloSchedule schedule(&assignment->module());
   1174     flat_hash_set<const LogicalBuffer*> all_buffers_to_assign;
   1175     for (const auto& pair : buffers_to_assign_sequentially) {
   1176       const HloComputation* computation = pair.first;
   1177       const flat_hash_set<const LogicalBuffer*>& buffers_to_assign =
   1178           pair.second;
   1179       const HloInstructionSequence* instruction_sequence =
   1180           hlo_ordering.SequentialOrder(*computation);
   1181       CHECK(instruction_sequence != nullptr) << computation->name();
   1182       schedule.set_sequence(computation, *instruction_sequence);
   1183       all_buffers_to_assign.insert(buffers_to_assign.begin(),
   1184                                    buffers_to_assign.end());
   1185     }
   1186     auto color_map = SplitBuffersByColor(all_buffers_to_assign);
   1187     for (auto& single_colored_set : color_map) {
   1188       auto color = single_colored_set.first;
   1189       VLOG(2) << "Simulating heap for color " << color;
   1190       int64 alignment = assignment->color_alignment_(color);
   1191       HeapSimulator::Options options;
   1192       options.alloc_constants = allocate_buffers_for_constants_;
   1193       BufferValueFlatSet buffer_value_set =
   1194           ToBufferValueFlatSet(single_colored_set.second);
   1195       options.buffers_to_assign = &buffer_value_set;
   1196       TF_ASSIGN_OR_RETURN(
   1197           const HeapSimulator::Result result,
   1198           HeapSimulator::Run(get_heap_algorithm(alignment),
   1199                              assignment->module(), schedule,
   1200                              assignment->points_to_analysis(),
   1201                              assignment->buffer_size_, options));
   1202       AssignBuffersFromHeapSimulator(result, assignment,
   1203                                      single_colored_set.first);
   1204     }
   1205   } else {
   1206     // Run the heap-simulation on a per-computation basis. Buffers for
   1207     // sub-computations are assigned disjoint BufferAllocations, assuming the
   1208     // worst-case that they may all be live concurrently.
   1209     VLOG(1) << "Running per-computation heap simulation";
   1210     for (const auto& pair : buffers_to_assign_sequentially) {
   1211       const HloComputation* computation = pair.first;
   1212       const flat_hash_set<const LogicalBuffer*>& buffers_to_assign =
   1213           pair.second;
   1214       const HloInstructionSequence* instruction_sequence =
   1215           hlo_ordering.SequentialOrder(*computation);
   1216       CHECK(instruction_sequence != nullptr) << computation->name();
   1217       auto color_map = SplitBuffersByColor(buffers_to_assign);
   1218       for (auto& single_colored_set : color_map) {
   1219         auto color = single_colored_set.first;
   1220         VLOG(2) << "Simulating heap for color " << color;
   1221         int64 alignment = assignment->color_alignment_(color);
   1222         HeapSimulator::Options options;
   1223         BufferValueFlatSet buffer_value_set =
   1224             ToBufferValueFlatSet(single_colored_set.second);
   1225         options.buffers_to_assign = &buffer_value_set;
   1226         TF_ASSIGN_OR_RETURN(
   1227             const HeapSimulator::Result result,
   1228             HeapSimulator::Run(get_heap_algorithm(alignment), *computation,
   1229                                *instruction_sequence,
   1230                                assignment->points_to_analysis(),
   1231                                assignment->buffer_size_, options));
   1232         AssignBuffersFromHeapSimulator(result, assignment,
   1233                                        single_colored_set.first);
   1234       }
   1235     }
   1236   }
   1237   return Status::OK();
   1238 }
   1239 
   1240 namespace {
   1241 
   1242 // Computes and returns the set of logical buffers live at the point of maximal
    1243 // liveness in the given heap trace. LogicalBuffers are (stably) sorted by id.
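         //
         // Worked example (illustrative): for the trace
         //   ALLOC a (8 bytes), ALLOC b (16), FREE a, ALLOC c (4), FREE b, FREE c
         // the live-set size evolves as 8, 24, 16, 20, 4, 0, so the maximal live size
         // is 24 bytes, first reached right after ALLOC b, and the function returns
         // {a, b} sorted by id.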
   1244 std::vector<const LogicalBuffer*> ComputePeakMemoryLogicalBuffers(
   1245     const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) {
   1246   // Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical
   1247   // buffers in this allocation.
   1248   absl::flat_hash_map<LogicalBuffer::Id, const LogicalBuffer*> id_to_buffer;
   1249   absl::flat_hash_map<const LogicalBuffer*, int64> buffer_sizes;
   1250   for (const auto& pair : allocation.assigned_buffers()) {
   1251     const LogicalBuffer* buffer = pair.first;
   1252     const BufferAllocation::OffsetSize& offset_size = pair.second;
   1253     id_to_buffer[buffer->id()] = buffer;
   1254     buffer_sizes[buffer] = offset_size.size;
   1255   }
   1256 
   1257   // Returns how much the given event increases the total size of live
   1258   // buffers. Can be negative.
   1259   auto memory_delta = [&id_to_buffer, &buffer_sizes](
   1260                           const HeapSimulatorTrace::Event& event) -> int64 {
   1261     const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id());
   1262     const int64 buffer_size = buffer_sizes.at(buffer);
   1263     if (event.kind() == HeapSimulatorTrace::Event::ALLOC) {
   1264       return buffer_size;
   1265     } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
   1266       // Sharing a buffer does not change the live set size for the purposes of
   1267       // the heap simulator. Even though the shared-with buffer may be smaller,
   1268       // the entire allocation remains live.
   1269       return 0;
   1270     } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
   1271       return -1 * buffer_size;
   1272     }
   1273     LOG(FATAL) << "Unknown event kind: " << event.kind();
   1274   };
   1275 
   1276   // First compute the size of the maximal live set.
   1277   int64 max_live_size = 0;
   1278   int64 live_size = 0;
   1279   for (const auto& event : heap_trace.events()) {
   1280     live_size += memory_delta(event);
   1281     if (max_live_size < live_size) {
   1282       max_live_size = live_size;
   1283     }
   1284   }
   1285 
   1286   // Next gather the set of logical buffers live at the earliest point of
   1287   // maximal live set size.
   1288   absl::flat_hash_set<const LogicalBuffer*> live_buffers;
   1289   live_size = 0;
   1290   for (const auto& event : heap_trace.events()) {
   1291     const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id());
   1292     if (event.kind() == HeapSimulatorTrace::Event::ALLOC) {
   1293       InsertOrDie(&live_buffers, buffer);
   1294     } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
   1295       // Nothing to do.
   1296     } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
   1297       CHECK(ContainsKey(live_buffers, buffer));
   1298       live_buffers.erase(buffer);
   1299     }
   1300 
   1301     live_size += memory_delta(event);
   1302     if (live_size == max_live_size) {
   1303       break;
   1304     }
   1305   }
   1306   CHECK_EQ(live_size, max_live_size);
   1307 
   1308   std::vector<const LogicalBuffer*> live_buffers_vector;
   1309   live_buffers_vector.insert(live_buffers_vector.end(), live_buffers.begin(),
   1310                              live_buffers.end());
   1311 
    1312   // Stably sort the live buffers.
   1313   absl::c_sort(live_buffers_vector,
   1314                [](const LogicalBuffer* a, const LogicalBuffer* b) {
   1315                  return a->id() < b->id();
   1316                });
   1317   return live_buffers_vector;
   1318 }
   1319 
   1320 }  // namespace
   1321 
   1322 void BufferAssigner::AssignBuffersFromHeapSimulator(
   1323     const HeapSimulator::Result& result, BufferAssignment* assignment,
   1324     LogicalBuffer::Color color) {
   1325   if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) {
   1326     assignment->stats_.preallocated_temp_fragmentation_bytes =
   1327         result.fragmentation_size;
   1328   } else {
   1329     assignment->stats_.preallocated_temp_fragmentation_bytes +=
   1330         result.fragmentation_size;
   1331   }
   1332 
   1333   BufferAllocation* allocation =
   1334       assignment->NewEmptyAllocation(result.heap_size, color);
   1335   for (const auto& buffer_chunk : result.chunk_map) {
   1336     // TODO(lauj) Remove this down_cast after downstream users of
   1337     // BufferAllocation::assigned_buffers() are updated to use BufferValue.
   1338     const LogicalBuffer& buffer =
   1339         *CHECK_NOTNULL(dynamic_cast<const LogicalBuffer*>(buffer_chunk.first));
   1340     const HeapSimulator::Chunk& chunk = buffer_chunk.second;
   1341     assignment->AddAssignment(allocation, buffer, chunk.offset, chunk.size);
   1342   }
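           // Illustrative example: with result.heap_size = 48 and chunk_map
           // {a: (offset 0, size 16), b: (offset 16, size 32)}, the loop above places
           // both logical buffers into the single 48-byte allocation at their
           // simulated, non-overlapping offsets.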
   1343   allocation->peak_buffers_ =
   1344       ComputePeakMemoryLogicalBuffers(*allocation, result.debug_trace);
   1345 
   1346   VLOG(1) << "Ran heap simulation for allocation: " << allocation->ToString();
   1347   allocation->AddHeapTrace(result.debug_trace);
   1348 }
   1349 
   1350 // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining
   1351 // the invariant that all sets in 'colocated_buffer_sets' are disjoint.
   1352 //
   1353 // A practical example of when this is necessary is a chain of kCall ops:
   1354 //   computation.entry
   1355 //     %a = call() -> computation.1
   1356 //   computation.1
   1357 //     %b = call() -> computation.2
   1358 //   computation.2
   1359 //     %c = parameter()
   1360 // This yields the logical sets {%a,%b} {%b,%c} {%c}, which need to be merged
    1361 // into a single set {%a,%b,%c}.
   1362 void BufferAssigner::AddSetToColocatedBufferSets(
   1363     const std::vector<const LogicalBuffer*>& colocated_set,
   1364     std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
   1365   if (colocated_set.empty()) {
   1366     return;
   1367   }
   1368   VLOG(5) << ColocatedBufferSetsToString(colocated_set,
   1369                                          "Adding colocated buffer set");
   1370   // Find existing sets that overlap with at least one buffer from the
   1371   // colocated_set. The resulting 'overlap_set_indices' will have at most
   1372   // colocated_buffer_sets->size() entries, and will be in increasing order.
   1373   std::vector<size_t> overlap_set_indices;
   1374   for (size_t index = 0; index < colocated_buffer_sets->size(); ++index) {
   1375     for (const LogicalBuffer* buffer : colocated_set) {
   1376       if ((*colocated_buffer_sets)[index].contains(buffer)) {
   1377         VLOG(5) << "Found overlap with existing set on buffer "
   1378                 << buffer->ToString() << "\n"
   1379                 << ColocatedBufferSetsToString((*colocated_buffer_sets)[index],
   1380                                                "Overlapping set");
   1381         overlap_set_indices.push_back(index);
   1382         break;
   1383       }
   1384     }
   1385   }
   1386 
   1387   // If there is no overlap with existing sets, create a new set.
   1388   if (overlap_set_indices.empty()) {
   1389     colocated_buffer_sets->emplace_back();
   1390     colocated_buffer_sets->back().insert(colocated_set.begin(),
   1391                                          colocated_set.end());
   1392     VLOG(5) << "No overlap found, new group created";
   1393     return;
   1394   }
   1395 
   1396   // Merge all overlap sets and the colocated set into the first overlap set.
   1397   ColocatedBufferSet* first = &(*colocated_buffer_sets)[overlap_set_indices[0]];
   1398   for (size_t index = 1; index < overlap_set_indices.size(); ++index) {
   1399     const ColocatedBufferSet& overlap_set =
   1400         (*colocated_buffer_sets)[overlap_set_indices[index]];
   1401     first->insert(overlap_set.begin(), overlap_set.end());
   1402   }
   1403   first->insert(colocated_set.begin(), colocated_set.end());
   1404   VLOG(5) << ColocatedBufferSetsToString(
   1405       *first, "Result of the colocated buffer set merging");
   1406 
   1407   // Remove overlap sets that we just merged. The offset accounts for the fact
   1408   // that as elements are erased, the indices need to be adjusted. Keep in mind
   1409   // that overlap_set_indices is in increasing order.
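           // For example, with overlap_set_indices = {0, 2, 5} everything was merged
           // into the set at index 0; the loop erases position 2 (= 2 - 1 + 1) first,
           // and then, since that erase shifted later sets left by one, position 4
           // (= 5 - 2 + 1).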
   1410   for (size_t index = 1; index < overlap_set_indices.size(); ++index) {
   1411     const size_t offset = overlap_set_indices[index] - index + 1;
   1412     colocated_buffer_sets->erase(colocated_buffer_sets->begin() + offset);
   1413   }
   1414 }
   1415 
   1416 std::vector<BufferAssigner::ColocatedBufferSet>
   1417 BufferAssigner::MergeColocatedBufferSets(
   1418     const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
   1419     const BufferLiveness& buffer_liveness,
   1420     const LogicalBuffer::SizeFunction& buffer_size) {
   1421   VLOG(1) << "colocation sets count before coalescing:"
   1422           << colocated_buffer_sets.size();
   1423 
    1424   // Returns true if the given buffer is a read-only (non-aliased) entry parameter.
   1425   auto is_readonly_entry_parameter = [](const LogicalBuffer& buffer) {
   1426     auto* instruction = buffer.instruction();
   1427     auto* computation = instruction->parent();
   1428     auto* module = computation->parent();
   1429     return instruction->opcode() == HloOpcode::kParameter &&
   1430            computation == module->entry_computation() &&
   1431            !module->input_output_alias_config().ParameterHasAlias(
   1432                instruction->parameter_number(), buffer.index());
   1433   };
   1434 
   1435   std::vector<bool> set_can_be_merged(colocated_buffer_sets.size(), true);
   1436 
   1437   // Do not merge if one of the sets includes live outs, entry parameters or
   1438   // constants.
   1439   //
   1440   // Buffer liveness does not report the correct live range for entry
   1441   // parameter and live out buffers so we have to special case them here.  On
   1442   // backends that support constant buffer allocations, constant buffers are
   1443   // assigned globals in readonly storage so we can't merge colocated buffer
   1444   // sets containing constants with colocated buffer sets containing writing
   1445   // instructions or other constants.
   1446   //
   1447   // Moreover (on the CPU/GPU backends) the entry parameter buffers belong to
   1448   // the caller of the executable so we can't write to entry parameters
   1449   // either, and the argument for not merging constants also applies to entry
   1450   // parameters.
   1451   for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) {
   1452     for (auto& buffer : colocated_buffer_sets[i]) {
   1453       if (buffer_liveness.MaybeLiveOut(*buffer) ||
   1454           is_readonly_entry_parameter(*buffer) ||
   1455           buffer->instruction()->opcode() == HloOpcode::kConstant) {
   1456         set_can_be_merged[i] = false;
   1457         break;
   1458       }
   1459     }
   1460   }
   1461 
    1462   // Returns true if the two colocated buffer sets (specified by their indices
    1463   // into colocated_buffer_sets) cannot be merged into a single set.
   1464   auto cannot_merge_buffer_sets = [&colocated_buffer_sets, &buffer_liveness,
   1465                                    &buffer_size,
   1466                                    &set_can_be_merged](int64 i, int64 j) {
   1467     if (!set_can_be_merged[i] || !set_can_be_merged[j]) {
   1468       return true;
   1469     }
   1470 
    1471     // Colocated sets satisfy the invariant that all buffers within a set have
    1472     // the same size. That means it is enough to compare the size of a single
    1473     // representative buffer from each set to decide whether the two sets have
    1474     // the same buffer size.
   1475     if (buffer_size(**colocated_buffer_sets[i].begin()) !=
   1476         buffer_size(**colocated_buffer_sets[j].begin())) {
   1477       return true;
   1478     }
   1479 
   1480     // Do not merge if some pair of buffers interferes with each other.
   1481     for (auto& buffer_a : colocated_buffer_sets[i]) {
   1482       for (auto& buffer_b : colocated_buffer_sets[j]) {
   1483         if (buffer_a->id() != buffer_b->id() &&
   1484             buffer_liveness.MayInterfere(*buffer_a, *buffer_b)) {
   1485           return true;
   1486         }
   1487       }
   1488     }
   1489 
   1490     return false;
   1491   };
   1492 
   1493   // Build the interference map among the colocated buffer sets (nodes), by
   1494   // adding an edge between any two nodes that cannot be merged into a single
   1495   // colocated buffer set.
   1496   std::vector<std::vector<int64>> interference_map(
   1497       colocated_buffer_sets.size());
   1498   for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) {
   1499     for (int64 j = i + 1; j < colocated_buffer_sets.size(); ++j) {
   1500       if (cannot_merge_buffer_sets(i, j)) {
   1501         interference_map[i].push_back(j);
   1502         interference_map[j].push_back(i);
   1503       }
   1504     }
   1505   }
   1506 
   1507   // Assign a color to each colocation set in colocated_buffer_sets, such that
   1508   // the sets that can be merged are assigned with the same color.
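           // For instance, with three sets where only sets 0 and 1 interfere
           // (interference_map = {{1}, {0}, {}}), one valid coloring is {0, 1, 0}; the
           // merge loop below then combines sets 0 and 2 into a single new set and
           // keeps set 1 separate, giving num_sets = 2.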
   1509   auto assigned_colors = ColorInterferenceGraph(interference_map);
   1510 
   1511   // Merge the buffer sets with the same color.
   1512   CHECK(!assigned_colors.empty());
   1513   int64 num_sets =
   1514       *std::max_element(assigned_colors.begin(), assigned_colors.end()) + 1;
   1515   std::vector<ColocatedBufferSet> new_colocated_buffer_sets(num_sets);
   1516   for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) {
   1517     const auto& buffer_set = colocated_buffer_sets[i];
   1518     new_colocated_buffer_sets[assigned_colors[i]].insert(buffer_set.begin(),
   1519                                                          buffer_set.end());
   1520   }
   1521 
   1522   VLOG(1) << "colocation sets count after coalescing:"
    1523           << new_colocated_buffer_sets.size();
   1524   return new_colocated_buffer_sets;
   1525 }
   1526 
   1527 // Builds sets of buffers in 'colocated_buffer_sets' which should be colocated
    1528 // in the same allocation (currently supports kWhile, kCall, kConditional, and
    1529 // input/output aliasing).
   1530 void BufferAssigner::BuildColocatedBufferSets(
   1531     const HloModule* module, const BufferLiveness& buffer_liveness,
   1532     const LogicalBuffer::SizeFunction& buffer_size,
   1533     std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
   1534   const TuplePointsToAnalysis& points_to_analysis =
   1535       buffer_liveness.points_to_analysis();
   1536 
   1537   // Set up colocated buffer set for input and output.
   1538   VLOG(4) << "Input/Output Alias Config: ";
   1539   VLOG(4) << module->input_output_alias_config();
   1540   module->input_output_alias_config().ForEachAlias(
   1541       [&](const ShapeIndex& output_index,
   1542           const HloInputOutputAliasConfig::Alias& alias) {
   1543         std::vector<const LogicalBuffer*> colocated_set;
   1544         AddBufferToColocatedSet(module->entry_computation()->root_instruction(),
   1545                                 output_index, points_to_analysis,
   1546                                 &colocated_set);
   1547         AddBufferToColocatedSet(
   1548             module->entry_computation()->parameter_instruction(
   1549                 alias.parameter_number),
   1550             alias.parameter_index, points_to_analysis, &colocated_set);
   1551         AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
   1552       });
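           // For example, if the alias config maps output index {1} to parameter 0 at
           // index {}, the root instruction's buffer at subshape {1} and entry
           // parameter 0's buffer form one colocated set and will end up sharing a
           // single allocation.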
   1553 
   1554   for (const HloComputation* computation : module->MakeComputationPostOrder()) {
   1555     if (computation->IsFusionComputation()) {
   1556       continue;
   1557     }
   1558     for (const HloInstruction* instruction :
   1559          computation->MakeInstructionPostOrder()) {
   1560       const HloOpcode opcode = instruction->opcode();
   1561       if (opcode == HloOpcode::kWhile) {
   1562         const HloInstruction* while_hlo = instruction;
   1563         ShapeUtil::ForEachSubshape(
   1564             while_hlo->shape(),
   1565             [this, while_hlo, &points_to_analysis, buffer_size,
   1566              colocated_buffer_sets](const Shape& /*subshape*/,
   1567                                     const ShapeIndex& index) {
   1568               std::vector<const LogicalBuffer*> colocated_set;
   1569               // Add while.init.
   1570               AddBufferToColocatedSet(while_hlo->operand(0), index,
   1571                                       points_to_analysis, &colocated_set);
   1572               // Add while.result.
   1573               AddBufferToColocatedSet(while_hlo, index, points_to_analysis,
   1574                                       &colocated_set);
   1575               // Add while.cond.parameter.
   1576               AddBufferToColocatedSet(
   1577                   while_hlo->while_condition()->parameter_instruction(0), index,
   1578                   points_to_analysis, &colocated_set);
   1579               // Add while.body.parameter.
   1580               AddBufferToColocatedSet(
   1581                   while_hlo->while_body()->parameter_instruction(0), index,
   1582                   points_to_analysis, &colocated_set);
   1583               // Add while.body.root.
   1584               AddBufferToColocatedSet(
   1585                   while_hlo->while_body()->root_instruction(), index,
   1586                   points_to_analysis, &colocated_set);
   1587               AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
   1588             });
   1589       } else if (opcode == HloOpcode::kCall) {
   1590         const HloInstruction* call_hlo = instruction;
   1591         const HloComputation* callee = call_hlo->to_apply();
   1592         const HloInstruction* root_hlo = callee->root_instruction();
   1593         for (int64 i = 0; i < call_hlo->operand_count(); i++) {
   1594           const HloInstruction* call_param = callee->parameter_instruction(i);
   1595           const HloInstruction* call_operand = call_hlo->operand(i);
   1596           ShapeUtil::ForEachSubshape(
   1597               call_operand->shape(),
   1598               [&](const Shape& /*subshape*/, const ShapeIndex& index) {
   1599                 std::vector<const LogicalBuffer*> colocated_set;
   1600                 AddBufferToColocatedSet(call_param, index, points_to_analysis,
   1601                                         &colocated_set);
   1602                 AddBufferToColocatedSet(call_operand, index, points_to_analysis,
   1603                                         &colocated_set);
   1604                 AddSetToColocatedBufferSets(colocated_set,
   1605                                             colocated_buffer_sets);
   1606               });
   1607         }
   1608         ShapeUtil::ForEachSubshape(
   1609             call_hlo->shape(),
   1610             [this, call_hlo, root_hlo, &points_to_analysis,
   1611              colocated_buffer_sets](const Shape& /*subshape*/,
   1612                                     const ShapeIndex& index) {
   1613               std::vector<const LogicalBuffer*> colocated_set;
   1614               // Add call.result.
   1615               AddBufferToColocatedSet(call_hlo, index, points_to_analysis,
   1616                                       &colocated_set);
   1617               // Add call.subcomputation.root.
   1618               AddBufferToColocatedSet(root_hlo, index, points_to_analysis,
   1619                                       &colocated_set);
   1620               AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
   1621             });
   1622       } else if (opcode == HloOpcode::kConditional) {
   1623         const HloInstruction* conditional = instruction;
   1624         ShapeUtil::ForEachSubshape(
   1625             conditional->shape(),
   1626             [this, conditional, &points_to_analysis, colocated_buffer_sets](
   1627                 const Shape& /*subshape*/, const ShapeIndex& index) {
   1628               std::vector<const LogicalBuffer*> colocated_set;
   1629               // Add cond.result.
   1630               AddBufferToColocatedSet(conditional, index, points_to_analysis,
   1631                                       &colocated_set);
   1632               for (int j = 0; j < conditional->branch_count(); ++j) {
   1633                 // Add each cond.branch_computation[j].root.
   1634                 AddBufferToColocatedSet(
   1635                     conditional->branch_computation(j)->root_instruction(),
   1636                     index, points_to_analysis, &colocated_set);
   1637               }
   1638               AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
   1639             });
   1640 
   1641         for (int j = 0; j < conditional->branch_count(); ++j) {
   1642           // Add branch_operand[j] (which is operand[j+1]) and
   1643           // cond.branch_computation[j].parameter(0) as a colocated
   1644           // buffer set. Note that this has to be done for each subshape in the
   1645           // branch_operand of the case.
   1646           ShapeUtil::ForEachSubshape(
   1647               conditional->operand(j + 1)->shape(),
   1648               [this, j, conditional, &points_to_analysis,
   1649                colocated_buffer_sets](const Shape& /*subshape*/,
   1650                                       const ShapeIndex& index) {
   1651                 std::vector<const LogicalBuffer*> branch_set;
   1652                 // Add cond.operand[j+1].
   1653                 AddBufferToColocatedSet(conditional->operand(j + 1), index,
   1654                                         points_to_analysis, &branch_set);
   1655                 // Add cond.branch_computation[j].parameter_instruction(0).
   1656                 AddBufferToColocatedSet(
   1657                     conditional->branch_computation(j)->parameter_instruction(
   1658                         0),
   1659                     index, points_to_analysis, &branch_set);
   1660                 AddSetToColocatedBufferSets(branch_set, colocated_buffer_sets);
   1661               });
   1662         }
   1663       }
   1664     }
   1665   }
   1666 
   1667   if (colocated_buffer_sets->empty()) {
   1668     return;
   1669   }
   1670 
   1671   int64 i = 0;
   1672   for (const auto& colocated_set : *colocated_buffer_sets) {
   1673     VLOG(4) << "Colocated set " << i++ << ":";
   1674     for (const auto& buffer : colocated_set) {
   1675       VLOG(4) << "  " << buffer->ToString();
   1676     }
   1677   }
   1678   // Try to find more coalescing opportunities among the colocated buffer sets.
   1679   //
   1680   // TODO(b/32491382): We should be able to remove this by using the
   1681   // module-level liveness analysis, which would let us directly detect buffer
   1682   // sharing opportunities between the while instruction buffer and the buffers
   1683   // from the predicate and body computation, as well as sharing across
   1684   // different while instructions.
   1685   std::vector<ColocatedBufferSet> new_colocated_buffer_sets =
   1686       MergeColocatedBufferSets(*colocated_buffer_sets, buffer_liveness,
   1687                                buffer_size);
   1688   std::swap(*colocated_buffer_sets, new_colocated_buffer_sets);
   1689 }
   1690 
   1691 // Assigns all colocated buffer sets in 'colocated_buffer_sets' to the same
   1692 // allocation in 'assignment'.
   1693 void BufferAssigner::AssignColocatedBufferSets(
   1694     const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
   1695     BufferAssignment* assignment,
   1696     flat_hash_set<const LogicalBuffer*>* colocated_buffers,
   1697     flat_hash_set<BufferAllocation::Index>* colocated_allocations) {
   1698   for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) {
   1699     BufferAllocation* allocation = nullptr;
    1700     // Set 'entry_parameter_number' and 'entry_parameter_shape_idx' if there is
    1701     // an entry computation parameter in 'colocated_buffer_set'.
   1702     int64 entry_parameter_number = -1;
   1703     const ShapeIndex* entry_parameter_shape_idx = nullptr;
   1704     bool is_constant = false;
   1705     for (const LogicalBuffer* buffer : colocated_buffer_set) {
   1706       const HloInstruction* instruction = buffer->instruction();
   1707       const HloComputation* computation = instruction->parent();
   1708       if (instruction->opcode() == HloOpcode::kParameter &&
   1709           computation == computation->parent()->entry_computation()) {
   1710         entry_parameter_number = instruction->parameter_number();
   1711         entry_parameter_shape_idx = &buffer->index();
   1712       } else if (instruction->opcode() == HloOpcode::kConstant) {
   1713         is_constant = true;
   1714       }
   1715     }
   1716 
   1717     CHECK(!is_constant || entry_parameter_number == -1)
   1718         << "Copy insertion should have inserted copies to prevent this.";
   1719 
   1720     for (const LogicalBuffer* buffer : colocated_buffer_set) {
   1721       const int64 buffer_size = assignment->buffer_size_(*buffer);
   1722       if (allocation == nullptr) {
   1723         // TODO(b/32491382) Avoid current trivial solution of using new
   1724         // allocations for each colocated buffer set. When liveness has
   1725         // module-level scope, we can allow buffers to be shared across
   1726         // computations (in some cases).
   1727         allocation = assignment->NewAllocation(*buffer, buffer_size);
   1728         if (is_constant) {
   1729           allocation->set_constant(true);
   1730         }
   1731         colocated_allocations->insert(allocation->index());
   1732       } else {
   1733         CHECK_EQ(buffer_size, allocation->size())
   1734             << "Buffer: " << *buffer << " size mismatch in colocated buffer "
   1735             << "allocation: " << *allocation;
   1736         assignment->AddAssignment(allocation, *buffer, /*offset=*/0,
   1737                                   buffer_size);
   1738       }
   1739       colocated_buffers->insert(buffer);
   1740     }
   1741 
    1742     // If the allocation contains an entry parameter, record it on the allocation.
   1743     if (entry_parameter_number >= 0) {
   1744       bool parameter_has_alias =
   1745           assignment->module().input_output_alias_config().ParameterHasAlias(
   1746               entry_parameter_number, *entry_parameter_shape_idx);
   1747       allocation->set_entry_computation_parameter(entry_parameter_number,
   1748                                                   *entry_parameter_shape_idx,
   1749                                                   parameter_has_alias);
   1750     }
   1751   }
   1752 }
   1753 
   1754 StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
   1755     const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
   1756     LogicalBuffer::SizeFunction buffer_size,
   1757     LogicalBuffer::AlignmentFunction color_alignment) {
   1758   TF_ASSIGN_OR_RETURN(std::unique_ptr<BufferLiveness> liveness,
   1759                       BufferLiveness::Run(module, std::move(hlo_ordering)));
   1760 
   1761   VLOG(1) << "Assigning buffers to module " << module->name();
   1762   XLA_VLOG_LINES(2, module->ToString());
   1763   XLA_VLOG_LINES(3, liveness->ToString());
   1764   XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString());
   1765 
   1766   // Can't use absl::make_unique because BufferAssignment constructor is
   1767   // private.
   1768   std::unique_ptr<BufferAssignment> assignment(
   1769       new BufferAssignment(module, std::move(liveness), std::move(buffer_size),
   1770                            std::move(color_alignment)));
   1771 
   1772   // Assign buffers with the tightest constraints first (colocated buffer sets).
   1773   // Once b/32491382 enables module-level liveness analysis, we may be able
   1774   // to assign colocated buffers (or at least reuse their allocation for
   1775   // buffers outside of the set) in AssignBuffersForComputation.
   1776   flat_hash_set<const LogicalBuffer*> colocated_buffers;
   1777   flat_hash_set<BufferAllocation::Index> colocated_allocations;
   1778   std::vector<ColocatedBufferSet> colocated_buffer_sets;
   1779   BuildColocatedBufferSets(module, assignment->liveness(),
   1780                            assignment->buffer_size_, &colocated_buffer_sets);
   1781   TF_RETURN_IF_ERROR(colorer_(assignment->liveness()));
   1782   VLOG(3) << "After coloring:";
   1783   XLA_VLOG_LINES(3, assignment->points_to_analysis().ToString());
   1784 
   1785   AssignColocatedBufferSets(colocated_buffer_sets, assignment.get(),
   1786                             &colocated_buffers, &colocated_allocations);
   1787 
   1788   std::vector<const HloComputation*> thread_local_computations;
   1789   std::vector<const HloComputation*> global_computations;
   1790   TF_RETURN_IF_ERROR(GatherComputationsByAllocationType(
   1791       module, &thread_local_computations, &global_computations));
   1792 
    1793   // First assign buffers for global computations. Temporary buffers for
   1794   // sequential computations are collected in 'buffers_to_assign_sequentially'.
   1795   flat_hash_map<const HloComputation*, flat_hash_set<const LogicalBuffer*>>
   1796       buffers_to_assign_sequentially;
   1797   for (auto* computation : global_computations) {
   1798     TF_RETURN_IF_ERROR(AssignBuffersForComputation(
   1799         computation,
   1800         /*is_thread_local=*/false, colocated_buffers, colocated_allocations,
   1801         &buffers_to_assign_sequentially, assignment.get()));
   1802   }
   1803   // Assign buffers with sequential ordering, if any. If all global computations
    1804   // are sequential, we can run heap simulation on the whole module, which
   1805   // reduces memory usage.
   1806   const bool run_whole_module_heap_simulation =
   1807       buffers_to_assign_sequentially.size() == global_computations.size();
   1808   TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering(
   1809       buffers_to_assign_sequentially, run_whole_module_heap_simulation,
   1810       assignment.get()));
   1811 
   1812   // Now assign buffers for thread-local computations. All LogicalBuffers get
   1813   // their own BufferAllocation.
   1814   for (auto* computation : thread_local_computations) {
   1815     TF_RET_CHECK(computation != module->entry_computation());
   1816     if (computation->IsFusionComputation()) {
   1817       continue;
   1818     }
   1819     TF_RETURN_IF_ERROR(AssignBuffersForComputation(
   1820         computation,
   1821         /*is_thread_local=*/true, colocated_buffers, colocated_allocations,
   1822         /*buffers_to_assign_sequentially=*/nullptr, assignment.get()));
   1823   }
   1824 
   1825   // Mark all buffers which may be live out of the entry computation as
   1826   // "liveout".
   1827   for (const LogicalBuffer* buffer :
   1828        assignment->liveness().maybe_live_out_buffers()) {
   1829     VLOG(3) << "maybe_live_out LogicalBuffer: " << *buffer;
   1830     if (assignment->HasAllocation(*buffer)) {
   1831       BufferAllocation* alloc =
   1832           assignment->GetMutableAssignedAllocation(*buffer);
   1833       alloc->set_maybe_live_out(true);
   1834       VLOG(3) << "maybe_live_out BufferAllocation: " << *alloc;
   1835     }
   1836   }
   1837 
   1838   // Combines allocations of temporary buffers into one big BufferAllocation.
   1839   // This can only be performed after all buffers have been assigned, and after
   1840   // maybe_live_out is marked, since it is used to determine whether an
   1841   // allocation contains temporary buffers or not.
   1842   assignment->CombineTempAllocations();
   1843 
   1844   XLA_VLOG_LINES(2, assignment->ToString());
   1845   TF_RETURN_IF_ERROR(assignment->ComputeSummaryStats());
   1846   XLA_VLOG_LINES(1, assignment->GetStats().ToString());
   1847   return std::move(assignment);
   1848 }
   1849 
   1850 }  // namespace xla
   1851