/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
     18 
#include <functional>
#include <iosfwd>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
     41 
     42 namespace xla {
     43 
     44 // Walk the call graph of the HLO module and place each computation into either
     45 // thread_local_computations or global_computations depending upon whether the
     46 // computation requires thread-local allocations or global allocations. The
     47 // elements in thread_local_computations and global_computations are in post
     48 // order (if computation A has an instruction which calls computation B, then A
     49 // will appear after B in the vector).
     50 Status GatherComputationsByAllocationType(
     51     const HloModule* module,
     52     std::vector<const HloComputation*>* thread_local_computations,
     53     std::vector<const HloComputation*>* global_computations);
     54 
     55 // This class abstracts an allocation of contiguous memory which can hold the
     56 // values described by LogicalBuffers. Each LogicalBuffer occupies a sub-range
     57 // of the allocation, represented by a Slice. A single BufferAllocation may hold
     58 // LogicalBuffers with disjoint liveness, which may have overlapping Slices. A
     59 // single BufferAllocation may also hold LogicalBuffers with overlapping
     60 // liveness, which must have disjoint Slices.
     61 //
     62 // The abstraction includes information required by the backends for allocation,
     63 // use, and deallocation of the buffer. This includes the LogicalBuffers which
     64 // are held in this allocation through the execution of the computation.
     65 class BufferAllocation {
     66  public:
     67   // Holds a unique identifier for each allocation. Values are assigned
     68   // contiguously and can be used as array indexes.
     69   using Index = int64;
     70 
     71   BufferAllocation(Index index, int64 size, LogicalBuffer::Color color)
     72       : index_(index), size_(size), color_(color) {}
     73   ~BufferAllocation() {}
     74 
     75   // Returns the index of this allocation.
     76   Index index() const { return index_; }
     77 
     78   // Whether this allocation is used in a parallel calling context such as
     79   // inside of a map or reduce computation. Such allocations need to be thread
     80   // local.
     81   bool is_thread_local() const { return is_thread_local_; }
     82   void set_is_thread_local(bool is_thread_local) {
     83     is_thread_local_ = is_thread_local;
     84   }
     85 
     86   // Whether this allocation can be used by more than one logical buffer.
     87   bool is_reusable() const {
     88     // We do not reuse thread-local buffers for now, because they are
     89     // dynamically allocated and their lifetimes are hard to compute.
     90     //
     91     // TODO(b/34669761): Don't reuse tuple buffers because the GPU backend
     92     // assumes longer buffer liveness than indicated by the analysis.
     93     return !is_thread_local() && !is_tuple();
     94   }
     95 
     96   // Whether this allocation is readonly i.e. backed by memory we cannot write
     97   // to.
     98   bool is_readonly() const {
     99     // Entry parameters are generally readonly, except when they are aliased
    100     // with any output.
    101     return (is_entry_computation_parameter() &&
    102             !is_parameter_aliased_with_output_) ||
    103            is_constant();
    104   }
    105 
    106   bool is_tuple() const { return is_tuple_; }
    107   void set_is_tuple(bool is_tuple) { is_tuple_ = is_tuple; }
    108 
    109   // Whether this allocation holds a LogicalBuffer from a parameter of the entry
    110   // computation. These buffers have lifetimes which may be longer than the
    111   // XLA computation.
    112   bool is_entry_computation_parameter() const {
    113     return is_entry_computation_parameter_;
    114   }
    115 
    116   // Whether this allocation holds a constant.  On the CPU and GPU backends
    117   // constant allocations are not allocated dynamically, instead we resolve
    118   // references to these buffer allocations to a global in the readonly section
    119   // of the binary.
    120   bool is_constant() const { return is_constant_; }
    121 
    122   // If this allocation holds a Buffer from a parameter of the entry
    123   // computation, this methods returns the parameter number. CHECKs otherwise.
    124   int64 parameter_number() const {
    125     CHECK(is_entry_computation_parameter_);
    126     return parameter_number_;
    127   }
    128 
    129   // If this allocation is for a parameter of the entry computation, this
    130   // function returns which subshape of the parameter the allocation is for.
    131   const ShapeIndex& param_shape_index() const {
    132     CHECK(is_entry_computation_parameter_);
    133     return param_shape_index_;
    134   }
    135 
    136   // Returns whether this allocation is assigned a LogicalBuffer which may
    137   // be live out of the entry computation.
    138   bool maybe_live_out() const { return maybe_live_out_; }
    139 
    140   // Returns the size of the allocation. Necessarily this must be at least as
    141   // large as any LogicalBuffer assigned to this allocation.
    142   int64 size() const { return size_; }
    143 
    144   // Returns the color of the allocation. Only logical buffers with a matching
    145   // color can reside in this allocation.
    146   LogicalBuffer::Color color() const { return color_; }
    147 
    148   struct OffsetSize {
    149     int64 offset = 0;
    150     int64 size = 0;
    151   };
    152 
    153   // Access to the logical buffers assigned to this allocation, and their
    154   // associated logical offsets and sizes.
    155   const absl::flat_hash_map<const LogicalBuffer*, OffsetSize>&
    156   assigned_buffers() const {
    157     return assigned_buffers_;
    158   }
    159 
    160   // A Slice represents a contiguous portion of a memory allocation. It is used
    161   // to identify the memory range that a LogicalBuffer corresponds to.
    162   class Slice {
    163    public:
    164     Slice() {}
    165     Slice(const BufferAllocation* allocation, int64 offset, int64 size)
    166         : allocation_(allocation), offset_(offset), size_(size) {}
    167 
    168     const BufferAllocation* allocation() const { return allocation_; }
    169     Index index() const { return allocation_->index(); }
    170     int64 offset() const { return offset_; }
    171     int64 size() const { return size_; }
    172 
    173     bool operator==(const Slice& other) const {
    174       return index() == other.index() && offset_ == other.offset_ &&
    175              size_ == other.size_;
    176     }
    177     bool operator!=(const Slice& other) const { return !(*this == other); }
    178     bool operator<(const Slice& other) const {
    179       if (index() != other.index()) return index() < other.index();
    180       if (offset_ != other.offset_) return offset_ < other.offset_;
    181       return size_ < other.size_;
    182     }
    183 
    184     // Returns true iff this slice's memory range has a non-empty intersection
    185     // with the other slice's memory range.
    186     bool OverlapsWith(const Slice& other) const {
    187       const int64 end = offset_ + size_;
    188       const int64 other_end = other.offset_ + other.size_;
    189       return index() == other.index() && offset_ < other_end &&
    190              end > other.offset_;
    191     }
    192 
    193     template <typename H>
    194     friend H AbslHashValue(H h, const Slice& s) {
    195       return H::combine(std::move(h), s.index(), s.offset(), s.size());
    196     }
    197 
    198     string ToString() const;
    199 
    200    private:
    201     const BufferAllocation* allocation_ = nullptr;
    202     int64 offset_ = 0;
    203     int64 size_ = 0;
    204   };
    205 
    206   // GetSlice returns the Slice of contiguous memory that holds the value
    207   // described by the given 'buffer'.
    208   // REQUIRES: 'buffer' must be assigned to this allocation.
    209   Slice GetSlice(const LogicalBuffer& buffer) const;
    210 
    211   string ToString() const;
    212   BufferAllocationProto ToProto() const;
    213 
    214   // Whether the buffer is a parameter to or live out of the entry computation.
    215   bool IsInputOrOutput() const {
    216     return is_entry_computation_parameter() || maybe_live_out();
    217   }
    218 
    219   // Whether the buffer is a temporary buffer allocated before
    220   // Executable::ExecuteOnStream.
    221   bool IsPreallocatedTempBuffer() const {
    222     // Parameters do not need temporary buffers.
    223     return !is_entry_computation_parameter() &&
    224            // LogicalBuffers that maybe pointed to by the output should live out
    225            // of the computation.
    226            !maybe_live_out() &&
    227            // Thread-local buffers are allocated using `alloca`s.
    228            !is_thread_local() &&
    229            // Constant buffers are allocated as global values.
    230            !is_constant();
    231   }
    232 
    233   // Add a heap trace which was used to assign slices to logical buffers in this
    234   // allocation. A single BufferAllocation may include multiple heap traces
    235   // in the case of the temporary block where there is a heap trace per
    236   // computation.
    237   void AddHeapTrace(const HeapSimulatorTrace& heap_trace) {
    238     heap_traces_.push_back(heap_trace);
    239   }
    240 
    241   // Return the set of heap traces used to assign slices to logical buffers in
    242   // this allocation.
    243   const std::vector<HeapSimulatorTrace> HeapTraces() const {
    244     return heap_traces_;
    245   }
    246 
    247   // Returns the LogicalBuffers which are live at the point of peak memory usage
    248   // for this allocation. The point of peak memory usage is the point at which
    249   // the total size of all live logical buffers is maximal. If peak memory is
    250   // reached at multiple points, the set of logical buffers live at the earliest
    251   // maximal point is returned. The vector is stabily sorted by
    252   // LogicalBuffer::Index.
    253   const std::vector<const LogicalBuffer*>& PeakMemoryLogicalBuffers() const {
    254     return peak_buffers_;
    255   }
    256 
    257   // Get the number of bytes lost to fragmentation. This is equal to the
    258   // difference between the size of the allocation and the size of the maximal
    259   // live set.
    260   int64 fragmentation_bytes() const { return fragmentation_bytes_; }
    261 
    262   bool operator==(const BufferAllocation& other) const {
    263     return index_ == other.index_;
    264   }
    265   bool operator!=(const BufferAllocation& other) const {
    266     return !(*this == other);
    267   }
    268   bool operator<(const BufferAllocation& other) const {
    269     return index() < other.index();
    270   }
    271 
    272  private:
    273   // Only BufferAssigner and BufferAssignment can modify BufferAllocation.
    274   friend class BufferAssigner;
    275   friend class BufferAssignment;
    276 
    277   // Adds a LogicalBuffer to the set assigned to this buffer.
    278   void AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size);
    279 
    280   void set_entry_computation_parameter(int64 parameter_number,
    281                                        ShapeIndex param_shape_index,
    282                                        bool parameter_aliased_with_output) {
    283     is_entry_computation_parameter_ = true;
    284     is_parameter_aliased_with_output_ = parameter_aliased_with_output;
    285     parameter_number_ = parameter_number;
    286     param_shape_index_ = std::move(param_shape_index);
    287   }
    288 
    289   void set_constant(bool is_constant) { is_constant_ = is_constant; }
    290   void set_maybe_live_out(bool value) { maybe_live_out_ = value; }
    291   void set_index(Index index) { index_ = index; }
    292   void set_size(int64 size) { size_ = size; }
    293 
    294   // The index of the allocation in the BufferAssignment.
    295   Index index_;
    296 
    297   // Size of the allocation in bytes.
    298   int64 size_;
    299 
    300   // Whether this buffer needs to be thread-local.
    301   bool is_thread_local_ = false;
    302 
    303   // Whether this buffer holds a tuple.
    304   bool is_tuple_ = false;
    305 
    306   // Color of the allocation.
    307   LogicalBuffer::Color color_;
    308 
    309   // Whether this allocation holds an entry computation parameter. Entry
    310   // computation parameters are special be cause they have lifetimes which may
    311   // outlast the computation.
    312   bool is_entry_computation_parameter_ = false;
    313 
    314   // Whether this entry computation parameter is aliased with output.
    315   bool is_parameter_aliased_with_output_ = false;
    316 
    317   // If this allocation holds an entry computation parameter, this field
    318   // indicates the index (starting from 0) of the parameter.
    319   int64 parameter_number_ = 0;
    320 
    321   // If this buffer is for an entry computation parameter, which subshape of the
    322   // parameter is it for?
    323   ShapeIndex param_shape_index_;
    324 
    325   // Whether the allocation contains a LogicalBuffer which may be live-out of
    326   // the entry computation. Note that this flag is conservatively computed by
    327   // TuplePointsToAnalysis.  That is, an allocation marked `maybe_live_out_`
    328   // might not actually escape.
    329   bool maybe_live_out_ = false;
    330 
    331   // See comment on the is_constant() accessor.
    332   bool is_constant_ = false;
    333 
    334   // Mapping from the set of buffers assigned to this allocation to their
    335   // logical offsets and sizes.
    336   absl::flat_hash_map<const LogicalBuffer*, OffsetSize> assigned_buffers_;
    337 
    338   int64 fragmentation_bytes_ = 0;
    339   std::vector<HeapSimulatorTrace> heap_traces_;
    340 
    341   // Set of buffers live at the point of peak memory usage for this allocation.
    342   std::vector<const LogicalBuffer*> peak_buffers_;
    343 };
    344 
    345 // Add stream operators for nicer output of CHECK/RET_CHECK failures.
    346 std::ostream& operator<<(std::ostream& out, const BufferAllocation& s);
    347 std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s);
    348 
    349 // This class encapsulates an assignment of the LogicalBuffers in an XLA
    350 // module to a set of BufferAllocations.
    351 class BufferAssignment {
    352  public:
    353   // Returns the vector containing all buffer allocations in this assignment.
    354   const std::vector<BufferAllocation>& Allocations() const {
    355     return allocations_;
    356   }
    357 
    358   // Returns the total size allocation holding all temporary buffers.
    359   int64 temp_allocation_total_size() const {
    360     return temp_allocation_total_size_;
    361   }
    362 
    363   // Returns whether the given buffer has been assigned an allocation.
    364   bool HasAllocation(const LogicalBuffer& buffer) const;
    365 
    366   // Returns the allocation that a particular LogicalBuffer has been assigned
    367   // to. CHECKs if buffer has not been assigned an allocation.
    368   const BufferAllocation& GetAssignedAllocation(
    369       const LogicalBuffer& buffer) const;
    370 
    371   // Returns the allocation with the given index. CHECKs if no allocation exists
    372   // with the given index.
    373   const BufferAllocation& GetAllocation(BufferAllocation::Index index) const;
    374 
    375   // Returns the allocation with the given instruction and shape index. nullptr
    376   // if no allocation exists.
    377   const BufferAllocation* GetInstructionAllocation(
    378       const HloInstruction* hlo, const ShapeIndex& shape_index) const;
    379 
    380   // Builds and returns a vector containing the slices which might contain the
    381   // subvalue at the given index of given instruction.
    382   std::set<BufferAllocation::Slice> GetAllSlices(
    383       const HloInstruction* instruction, const ShapeIndex& index) const;
    384 
    385   // Convenience function which returns whether the buffer of the
    386   // instruction at the given index is assigned an allocation.
    387   bool HasAllocationAt(const HloInstruction* instruction,
    388                        const ShapeIndex& index) const;
    389 
    390   // Convenience function which returns whether the top-level buffer of the
    391   // instruction (index == {}) is assigned an allocation.
    392   bool HasTopLevelAllocation(const HloInstruction* instruction) const;
    393 
    394   // Convenience function which returns the unique slice containing the buffer
    395   // at the given index of the given instruction. If a slice is not assigned or
    396   // the slice cannot be determined at compile time then an error is returned.
    397   StatusOr<BufferAllocation::Slice> GetUniqueSlice(
    398       const HloInstruction* instruction, const ShapeIndex& index) const;
    399   // Like GetUniqueSlice but fixes the index to the top-level of the shape
    400   // (index = {}).
    401   StatusOr<BufferAllocation::Slice> GetUniqueTopLevelSlice(
    402       const HloInstruction* instruction) const;
    403   // Like GetUniqueTopLevelSlice but returns the slice for the output of the
    404   // entry computation of the HLO module (ie, the result of the XLA
    405   // computation).
    406   StatusOr<BufferAllocation::Slice> GetUniqueTopLevelOutputSlice() const;
    407 
    408   // Returns the set LogicalBuffers which may be the source of the value at the
    409   // given index and instruction.
    410   const PointsToSet::BufferList& GetSourceBuffers(
    411       const HloInstruction* instruction, const ShapeIndex& index) const {
    412     return GetPointsToSet(instruction).element(index);
    413   }
    414 
    415   // Returns true if 'hlo_a{shape_index_a}' and 'hlo_b{shape_index_b}'
    416   // share the same BufferAllocation::Slice.
    417   // Returns false otherwise.
    418   // REQUIRES: BufferAssignment assigned allocations to both instructions.
    419   bool SharesSliceAtIndex(const HloInstruction* hlo_a,
    420                           const ShapeIndex& shape_index_a,
    421                           const HloInstruction* hlo_b,
    422                           const ShapeIndex& shape_index_b) const;
    423 
    424   // Returns true if the top-level buffers of hlo_a and hlo_b are the same.
    425   // REQUIRES: HasTopLevelAllocation(hlo_a) && HasTopLevelAllocation(hlo_b).
    426   bool SharesTopLevelSlice(const HloInstruction* hlo_a,
    427                            const HloInstruction* hlo_b) const {
    428     return SharesSliceAtIndex(hlo_a, {}, hlo_b, {});
    429   }
    430 
    431   // Returns true if hlo_a and hlo_b both have at least one buffer assigned for
    432   // their top-level and each of their nested shape indices, and if hlo_a's
    433   // buffers are all different from hlo_b's buffers.
    434   bool HaveDisjointSlices(const HloInstruction* hlo_a,
    435                           const HloInstruction* hlo_b) const;
    436 
    437   // Returns the underlying points-to analysis used for this assignment.
    438   const TuplePointsToAnalysis& points_to_analysis() const {
    439     return liveness_->points_to_analysis();
    440   }
    441 
    442   // Returns the BufferLiveness object used to construct this assignment.
    443   const BufferLiveness& liveness() const { return *liveness_; }
    444 
    445   string ToString() const;
    446   BufferAssignmentProto ToProto() const;
    447 
    448   // Statistics for the assignment.  Values initialized to -1 are not always
    449   // collected; fragmentation is only collected for instructions that have a
    450   // sequential total ordering.
    451   struct Stats {
    452     int64 parameter_allocation_count = 0;
    453     int64 parameter_allocation_bytes = 0;
    454     int64 constant_allocation_count = 0;
    455     int64 constant_allocation_bytes = 0;
    456     int64 maybe_live_out_allocation_count = 0;
    457     int64 maybe_live_out_allocation_bytes = 0;
    458     int64 preallocated_temp_allocation_count = 0;
    459     int64 preallocated_temp_allocation_bytes = 0;
    460     int64 preallocated_temp_fragmentation_bytes = -1;
    461     int64 total_allocation_count = 0;
    462     int64 total_allocation_bytes = 0;
    463     int64 total_fragmentation_bytes = -1;
    464 
    465     string ToString() const;
    466   };
    467   const Stats& GetStats() const { return stats_; }
    468 
    469  private:
    470   // Only BufferAssigner can build or modify BufferAssignments.
    471   friend class BufferAssigner;
    472 
    473   BufferAssignment(const HloModule* module,
    474                    std::unique_ptr<BufferLiveness> liveness,
    475                    LogicalBuffer::SizeFunction buffer_size,
    476                    LogicalBuffer::AlignmentFunction color_alignment)
    477       : module_(module),
    478         liveness_(std::move(liveness)),
    479         buffer_size_(std::move(buffer_size)),
    480         color_alignment_(std::move(color_alignment)) {}
    481 
    482   // Creates and returns a new BufferAllocation, with no assigned
    483   // LogicalBuffers. Ownership is maintained internally.
    484   BufferAllocation* NewEmptyAllocation(int64 size, LogicalBuffer::Color color);
    485 
    486   // Helper that calls NewEmptyAllocation and AddAssignment in one call,
    487   // creating an allocation containing a single LogicalBuffer.
    488   BufferAllocation* NewAllocation(const LogicalBuffer& buffer, int64 size);
    489 
    490   // Adds a LogicalBuffer to the set assigned to the given allocation.
    491   void AddAssignment(BufferAllocation* allocation, const LogicalBuffer& buffer,
    492                      int64 offset, int64 size);
    493 
    494   // Returns the HloModule used to construct this assignment.
    495   const HloModule& module() const { return *module_; }
    496 
    497   // Convenience function which returns the PointsToSet for the given
    498   // instruction. Extracted from the liveness object.
    499   const PointsToSet& GetPointsToSet(const HloInstruction* instruction) const;
    500 
    501   // Mutable accessors for allocations.
    502   BufferAllocation* GetMutableAssignedAllocation(const LogicalBuffer& buffer);
    503   BufferAllocation* GetMutableAllocation(BufferAllocation::Index index);
    504 
    505   // Combines allocations of temporary buffers into one big BufferAllocation.
    506   void CombineTempAllocations();
    507 
    508   // Computes stats for the assignment, to be retrieved by GetStats.
    509   Status ComputeSummaryStats();
    510 
    511   // The vector of buffer allocations. Indexed by BufferAllocation::Index.
    512   std::vector<BufferAllocation> allocations_;
    513 
    514   // The total size of all temporary buffers.
    515   int64 temp_allocation_total_size_ = 0;
    516 
    517   // Maps Buffers to the index of the BufferAllocation which holds the buffer.
    518   absl::flat_hash_map<const LogicalBuffer*, BufferAllocation::Index>
    519       allocation_index_for_buffer_;
    520 
    521   const HloModule* module_;
    522   const std::unique_ptr<BufferLiveness> liveness_;
    523 
    524   // Function which returns the buffer size for a given logical buffer (shape).
    525   LogicalBuffer::SizeFunction buffer_size_;
    526 
    527   // Function which returns the alignment for a given logical buffer color.
    528   LogicalBuffer::AlignmentFunction color_alignment_;
    529 
    530   Stats stats_;
    531 
    532   TF_DISALLOW_COPY_AND_ASSIGN(BufferAssignment);
    533 };
    534 
    535 // A class which constructs a buffer assignment.
    536 class BufferAssigner {
    537  public:
    538   // Returns false if a buffer cannot be assigned to given allocation.
    539   using ReuseAllocationFunction = std::function<bool(
    540       const BufferAssignment& assignment, const BufferAllocation& alloc,
    541       const LogicalBuffer& buffer)>;
    542 
    543   // Build and return a BufferAssignment for the given module. The given
    544   // HloOrdering is used to determine buffer liveness. buffer_size and
    545   // color_alignment are functions which returns the size and alignment of a
    546   // LogicalBuffer.  allow_input_output_aliasing specifies whether input buffer
    547   // are allowed to be reused as outbut buffers by the client code.
    548   static StatusOr<std::unique_ptr<BufferAssignment>> Run(
    549       const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
    550       LogicalBuffer::SizeFunction buffer_size,
    551       LogicalBuffer::AlignmentFunction color_alignment,
    552       bool allow_input_output_aliasing = false,
    553       bool allocate_buffers_for_constants = false,
    554       BufferLiveness::Colorer colorer = BufferLiveness::DefaultColorer(),
    555       ReuseAllocationFunction reuse_checker = nullptr);
    556 
    557  private:
    558   BufferAssigner(bool allocate_buffers_for_constants,
    559                  BufferLiveness::Colorer colorer,
    560                  ReuseAllocationFunction reuse_checker)
    561       : allocate_buffers_for_constants_(allocate_buffers_for_constants),
    562         colorer_(colorer),
    563         reuse_checker_(reuse_checker) {}
    564   virtual ~BufferAssigner() = default;
    565 
    566   // Create a buffer assignment.
    567   StatusOr<std::unique_ptr<BufferAssignment>> CreateAssignment(
    568       const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
    569       LogicalBuffer::SizeFunction buffer_size,
    570       LogicalBuffer::AlignmentFunction color_alignment);
    571 
    572   // Assigns buffers to the instructions in the given computation. "assignment"
    573   // is modified to reflect the new buffer assignments. If is_thread_local is
    574   // true, then all assigned buffers have the is_thread_local flag set to
    575   // true.
    576   Status AssignBuffersForComputation(
    577       const HloComputation* computation, bool is_thread_local,
    578       const absl::flat_hash_set<const LogicalBuffer*>& colocated_buffers,
    579       const absl::flat_hash_set<BufferAllocation::Index>& colocated_allocations,
    580       absl::flat_hash_map<const HloComputation*,
    581                           absl::flat_hash_set<const LogicalBuffer*>>*
    582           buffers_to_assign_sequentially,
    583       BufferAssignment* assignment);
    584 
    585   // Assigns 'buffers_to_assign_sequentially' using heap simulation, assuming
    586   // the HLO instructions will be executed in the sequential order given by
    587   // assignment->liveness().hlo_ordering().SequentialOrder. If
    588   // 'run_whole_module_heap_simulation' is true, the heap simulation will be run
    589   // assuming all global computations are sequentially ordered.
    590   Status AssignBuffersWithSequentialOrdering(
    591       const absl::flat_hash_map<const HloComputation*,
    592                                 absl::flat_hash_set<const LogicalBuffer*>>&
    593           buffers_to_assign_sequentially,
    594       bool run_whole_module_heap_simulation, BufferAssignment* assignment);
    595 
    596   // Uses the results of the heap simulator to create a single allocation, with
    597   // LogicalBuffers packed to specific offsets.
    598   void AssignBuffersFromHeapSimulator(const HeapSimulator::Result& result,
    599                                       BufferAssignment* assignment,
    600                                       LogicalBuffer::Color color);
    601 
    602   // Tries to assign the given instruction to the given buffer. Returns if the
    603   // assignment was successful.
    604   bool MaybeAssignBuffer(BufferAllocation* allocation,
    605                          const LogicalBuffer& buffer,
    606                          BufferAssignment* assignment);
    607 
    608   // Colocated buffers are logical buffers from different computations which
    609   // alias. Explicitly handling these colocated buffers is necessary because
    610   // points-to analysis is computation level scope and does not recognize
    611   // aliasing across computations (b/32491382).
    612   using ColocatedBufferSet = absl::flat_hash_set<const LogicalBuffer*>;
    613 
    614   // Returns a vector of ColocatedBufferSet objects, where each
    615   // ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module'
    616   // which should be colocated in the same buffer allocation.
    617   void BuildColocatedBufferSets(
    618       const HloModule* module, const BufferLiveness& buffer_liveness,
    619       const LogicalBuffer::SizeFunction& buffer_size,
    620       std::vector<ColocatedBufferSet>* colocated_buffer_sets);
    621 
    622   // For each buffer set in 'colocated_buffer_sets', assigns all buffers in the
    623   // same set to the same buffer allocation in 'assignment'.
    624   void AssignColocatedBufferSets(
    625       const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
    626       BufferAssignment* assignment,
    627       absl::flat_hash_set<const LogicalBuffer*>* colocated_buffers,
    628       absl::flat_hash_set<BufferAllocation::Index>* colocated_allocations);
    629 
    630   // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining
    631   // the invariant that all sets in 'colocated_buffer_sets' are disjoint.
    632   void AddSetToColocatedBufferSets(
    633       const std::vector<const LogicalBuffer*>& colocated_set,
    634       std::vector<ColocatedBufferSet>* colocated_buffer_sets);
    635 
    636   // Given a list of colocated buffer sets (each colocated buffer set represents
    637   // the logical buffers that would be assigned to the same physical buffer),
    638   // try to merge the sets if the buffers can be shared. Returns the merged set.
    639   std::vector<ColocatedBufferSet> MergeColocatedBufferSets(
    640       const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
    641       const BufferLiveness& buffer_liveness,
    642       const LogicalBuffer::SizeFunction& buffer_size);
    643 
    644   // Split a set of buffers into several sets, each of which contains buffers
    645   // colored with the same color.
    646   absl::flat_hash_map<LogicalBuffer::Color,
    647                       absl::flat_hash_set<const LogicalBuffer*>,
    648                       LogicalBuffer::Color::Hasher>
    649   SplitBuffersByColor(const absl::flat_hash_set<const LogicalBuffer*>& buffers);
    650 
    651   // If true, allocate buffers for constant instructions.
    652   bool allocate_buffers_for_constants_;
    653 
    654   // Functor used to assign colors to newly allocated logical buffers.
    655   BufferLiveness::Colorer colorer_;
    656 
    657   // Functor to check if a buffer can reuse an allocation.
    658   ReuseAllocationFunction reuse_checker_;
    659 
    660   TF_DISALLOW_COPY_AND_ASSIGN(BufferAssigner);
    661 };
    662 
    663 }  // namespace xla
    664 
    665 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
    666