Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
     18 
#include <functional>
#include <iosfwd>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
     42 
     43 namespace xla {
     44 
     45 // This class abstracts an allocation of contiguous memory which can hold the
     46 // values described by LogicalBuffers. Each LogicalBuffer occupies a sub-range
     47 // of the allocation, represented by a Slice. A single BufferAllocation may hold
     48 // LogicalBuffers with disjoint liveness, which may have overlapping Slices. A
     49 // single BufferAllocation may also hold LogicalBuffers with overlapping
     50 // liveness, which must have disjoint Slices.
     51 //
     52 // The abstraction includes information required by the backends for allocation,
     53 // use, and deallocation of the buffer. This includes the LogicalBuffers which
     54 // are held in this allocation through the execution of the computation.
     55 class BufferAllocation {
     56  public:
     57   // Holds a unique identifier for each allocation. Values are assigned
     58   // contiguously and can be used as array indexes.
     59   using Index = int64;
     60 
     61   BufferAllocation(Index index, int64 size, bool is_thread_local,
     62                    bool is_reusable, LogicalBuffer::Color color)
     63       : index_(index),
     64         size_(size),
     65         is_thread_local_(is_thread_local),
     66         is_reusable_(is_reusable),
     67         color_(color) {}
     68   ~BufferAllocation() {}
     69 
     70   // Returns the index of this allocation.
     71   Index index() const { return index_; }
     72 
     73   // Whether this allocation is used in a parallel calling context such as
     74   // inside of a map or reduce computation. Such allocations need to be thread
     75   // local.
     76   bool is_thread_local() const { return is_thread_local_; }
     77 
     78   // Whether this allocation can be used by more than one logical buffer.
     79   bool is_reusable() const { return is_reusable_; }
     80 
     81   // Whether this allocation holds a LogicalBuffer from a parameter of the entry
     82   // computation. These buffers have lifetimes which may be longer than the
     83   // XLA computation.
     84   bool is_entry_computation_parameter() const {
     85     return is_entry_computation_parameter_;
     86   }
     87   // If this allocation holds a Buffer from a parameter of the entry
     88   // computation, this methods returns the parameter number. CHECKs otherwise.
     89   int64 parameter_number() const {
     90     CHECK(is_entry_computation_parameter_);
     91     return parameter_number_;
     92   }
     93 
     94   // If this allocation is for a parameter of the entry computation, this
     95   // function returns which subshape of the parameter the allocation is for.
     96   const ShapeIndex& param_shape_index() const {
     97     CHECK(is_entry_computation_parameter_);
     98     return param_shape_index_;
     99   }
    100 
    101   // Returns whether this allocation is assigned a LogicalBuffer which may
    102   // be live out of the entry computation.
    103   bool maybe_live_out() const { return maybe_live_out_; }
    104 
    105   // Returns the size of the allocation. Necessarily this must be at least as
    106   // large as any LogicalBuffer assigned to this allocation.
    107   int64 size() const { return size_; }
    108 
    109   // Returns the color of the allocation. Only logical buffers with a matching
    110   // color can reside in this allocation.
    111   LogicalBuffer::Color color() const { return color_; }
    112 
    113   struct OffsetSize {
    114     int64 offset = 0;
    115     int64 size = 0;
    116   };
    117 
    118   // Access to the logical buffers assigned to this allocation, and their
    119   // associated logical offsets and sizes.
    120   const tensorflow::gtl::FlatMap<const LogicalBuffer*, OffsetSize>&
    121   assigned_buffers() const {
    122     return assigned_buffers_;
    123   }
    124 
    125   // A Slice represents a contiguous portion of a memory allocation. It is used
    126   // to identify the memory range that a LogicalBuffer corresponds to.
    127   class Slice {
    128    public:
    129     Slice() {}
    130     Slice(const BufferAllocation* allocation, int64 offset, int64 size)
    131         : allocation_(allocation), offset_(offset), size_(size) {}
    132 
    133     const BufferAllocation* allocation() const { return allocation_; }
    134     Index index() const { return allocation_->index(); }
    135     int64 offset() const { return offset_; }
    136     int64 size() const { return size_; }
    137 
    138     bool operator==(const Slice& other) const {
    139       return index() == other.index() && offset_ == other.offset_ &&
    140              size_ == other.size_;
    141     }
    142     bool operator!=(const Slice& other) const { return !(*this == other); }
    143     bool operator<(const Slice& other) const {
    144       if (index() != other.index()) return index() < other.index();
    145       if (offset_ != other.offset_) return offset_ < other.offset_;
    146       return size_ < other.size_;
    147     }
    148 
    149     // Returns true iff this slice's memory range has a non-empty intersection
    150     // with the other slice's memory range.
    151     bool OverlapsWith(const Slice& other) const {
    152       const int64 end = offset_ + size_;
    153       const int64 other_end = other.offset_ + other.size_;
    154       return index() == other.index() && offset_ < other_end &&
    155              end > other.offset_;
    156     }
    157 
    158     struct Hasher {
    159       size_t operator()(Slice s) const;
    160     };
    161 
    162     string ToString() const;
    163 
    164    private:
    165     const BufferAllocation* allocation_ = nullptr;
    166     int64 offset_ = 0;
    167     int64 size_ = 0;
    168   };
    169 
    170   // GetSlice returns the Slice of contiguous memory that holds the value
    171   // described by the given 'buffer'.
    172   // REQUIRES: 'buffer' must be assigned to this allocation.
    173   Slice GetSlice(const LogicalBuffer& buffer) const;
    174 
    175   string ToString() const;
    176   BufferAllocationProto ToProto() const;
    177 
    178   // Whether the buffer is a parameter to or live out of the entry computation.
    179   bool IsInputOrOutput() const {
    180     return is_entry_computation_parameter() || maybe_live_out();
    181   }
    182 
    183   // Whether the buffer is a temporary buffer allocated before
    184   // Executable::ExecuteOnStream.
    185   bool IsPreallocatedTempBuffer() const {
    186     // Parameters do not need temporary buffers.
    187     return !is_entry_computation_parameter() &&
    188            // LogicalBuffers that maybe pointed to by the output should live out
    189            // of the computation.
    190            !maybe_live_out() &&
    191            // Thread-local buffers are allocated using `alloca`s.
    192            !is_thread_local();
    193   }
    194 
    195   bool operator==(const BufferAllocation& other) const {
    196     return index_ == other.index_;
    197   }
    198   bool operator!=(const BufferAllocation& other) const {
    199     return !(*this == other);
    200   }
    201   bool operator<(const BufferAllocation& other) const {
    202     return index() < other.index();
    203   }
    204 
    205  private:
    206   // Only BufferAssigner and BufferAssignment can modify BufferAllocation.
    207   friend class BufferAssigner;
    208   friend class BufferAssignment;
    209 
    210   // Adds a LogicalBuffer to the set assigned to this buffer.
    211   void AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size);
    212 
    213   void set_entry_computation_parameter(int64 parameter_number,
    214                                        ShapeIndex param_shape_index) {
    215     is_entry_computation_parameter_ = true;
    216     parameter_number_ = parameter_number;
    217     param_shape_index_ = std::move(param_shape_index);
    218   }
    219   void set_maybe_live_out(bool value) { maybe_live_out_ = value; }
    220   void set_index(Index index) { index_ = index; }
    221   void set_size(int64 size) { size_ = size; }
    222 
    223   // The index of the allocation in the BufferAssignment.
    224   Index index_;
    225 
    226   // Size of the allocation in bytes.
    227   int64 size_;
    228 
    229   // Whether this buffer needs to be thread-local.
    230   bool is_thread_local_;
    231 
    232   // Whether this buffer is usable by more than one logical buffer.
    233   bool is_reusable_;
    234 
    235   // Color of the allocation.
    236   LogicalBuffer::Color color_;
    237 
    238   // Whether this allocation holds an entry computation parameter. Entry
    239   // computation parameters are special be cause they have lifetimes which may
    240   // outlast the computation.
    241   bool is_entry_computation_parameter_ = false;
    242 
    243   // If this allocation holds an entry computation parameter, this field
    244   // indicates the index (starting from 0) of the parameter.
    245   int64 parameter_number_ = 0;
    246 
    247   // If this buffer is for an entry computation parameter, which subshape of the
    248   // parameter is it for?
    249   ShapeIndex param_shape_index_;
    250 
    251   // Whether the allocation contains a LogicalBuffer which may be live-out of
    252   // the entry computation. Note that this flag is conservatively computed by
    253   // TuplePointsToAnalysis.  That is, an allocation marked `maybe_live_out_`
    254   // might not actually escape.
    255   bool maybe_live_out_ = false;
    256 
    257   // Mapping from the set of buffers assigned to this allocation to their
    258   // logical offsets and sizes.
    259   tensorflow::gtl::FlatMap<const LogicalBuffer*, OffsetSize> assigned_buffers_;
    260 };
    261 
    262 // Add stream operators for nicer output of CHECK/RET_CHECK failures.
    263 std::ostream& operator<<(std::ostream& out, const BufferAllocation& s);
    264 std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s);
    265 
    266 // This class encapsulates an assignment of the LogicalBuffers in an XLA
    267 // module to a set of BufferAllocations.
    268 class BufferAssignment {
    269  public:
    270   // Returns the vector containing all buffer allocations in this assignment.
    271   const std::vector<BufferAllocation>& Allocations() const {
    272     return allocations_;
    273   }
    274 
    275   // Returns the total size allocation holding all temporary buffers.
    276   int64 temp_allocation_total_size() const {
    277     return temp_allocation_total_size_;
    278   }
    279 
    280   // Returns whether the given buffer has been assigned an allocation.
    281   bool HasAllocation(const LogicalBuffer& buffer) const;
    282 
    283   // Returns the allocation that a particular LogicalBuffer has been assigned
    284   // to. CHECKs if buffer has not been assigned an allocation.
    285   const BufferAllocation& GetAssignedAllocation(
    286       const LogicalBuffer& buffer) const;
    287 
    288   // Returns the allocation with the given index. CHECKs if no allocation exists
    289   // with the given index.
    290   const BufferAllocation& GetAllocation(BufferAllocation::Index index) const;
    291 
    292   // Builds and returns a vector containing the slices which might contain the
    293   // subvalue at the given index of given instruction.
    294   std::set<BufferAllocation::Slice> GetAllSlices(
    295       const HloInstruction* instruction, const ShapeIndex& index) const;
    296 
    297   // Convenience function which returns whether the buffer of the
    298   // instruction at the given index is assigned an allocation.
    299   bool HasAllocationAt(const HloInstruction* instruction,
    300                        const ShapeIndex& index) const;
    301 
    302   // Convenience function which returns whether the top-level buffer of the
    303   // instruction (index == {}) is assigned an allocation.
    304   bool HasTopLevelAllocation(const HloInstruction* instruction) const;
    305 
    306   // Convenience function which returns the unique slice containing the buffer
    307   // at the given index of the given instruction. If a slice is not assigned or
    308   // the slice cannot be determined at compile time then an error is returned.
    309   StatusOr<BufferAllocation::Slice> GetUniqueSlice(
    310       const HloInstruction* instruction, const ShapeIndex& index) const;
    311   // Like GetUniqueSlice but fixes the index to the top-level of the shape
    312   // (index = {}).
    313   StatusOr<BufferAllocation::Slice> GetUniqueTopLevelSlice(
    314       const HloInstruction* instruction) const;
    315   // Like GetUniqueTopLevelSlice but returns the slice for the output of the
    316   // entry computation of the HLO module (ie, the result of the XLA
    317   // computation).
    318   StatusOr<BufferAllocation::Slice> GetUniqueTopLevelOutputSlice() const;
    319 
    320   // Returns the set LogicalBuffers which may be the source of the value at the
    321   // given index and instruction.
    322   const PointsToSet::BufferList& GetSourceBuffers(
    323       const HloInstruction* instruction, const ShapeIndex& index) const {
    324     return GetPointsToSet(instruction).element(index);
    325   }
    326 
    327   // Returns true if 'hlo_a{shape_index_a}' and 'hlo_b{shape_index_b}'
    328   // share the same BufferAllocation::Slice.
    329   // Returns false otherwise.
    330   // REQUIRES: BufferAssignment assigned allocations to both instructions.
    331   bool SharesSliceAtIndex(const HloInstruction* hlo_a,
    332                           const ShapeIndex& shape_index_a,
    333                           const HloInstruction* hlo_b,
    334                           const ShapeIndex& shape_index_b) const;
    335 
    336   // Returns true if the top-level buffers of hlo_a and hlo_b are the same.
    337   // REQUIRES: HasTopLevelAllocation(hlo_a) && HasTopLevelAllocation(hlo_b).
    338   bool SharesTopLevelSlice(const HloInstruction* hlo_a,
    339                            const HloInstruction* hlo_b) const {
    340     return SharesSliceAtIndex(hlo_a, {}, hlo_b, {});
    341   }
    342 
    343   // Returns true if hlo_a and hlo_b both have at least one buffer assigned for
    344   // their top-level and each of their nested shape indices, and if hlo_a's
    345   // buffers are all different from hlo_b's buffers.
    346   bool HaveDisjointSlices(const HloInstruction* hlo_a,
    347                           const HloInstruction* hlo_b) const;
    348 
    349   // Returns the underlying points-to analysis used for this assignment.
    350   const TuplePointsToAnalysis& points_to_analysis() const {
    351     return liveness_->points_to_analysis();
    352   }
    353 
    354   // Returns the BufferLiveness object used to construct this assignment.
    355   const BufferLiveness& liveness() const { return *liveness_; }
    356 
    357   string ToString() const;
    358   BufferAssignmentProto ToProto() const;
    359 
    360   // Statistics for the assignment.  Values initialized to -1 are not always
    361   // collected; fragmentation is only collected for instructions that have a
    362   // sequential total ordering.
    363   struct Stats {
    364     int64 parameter_allocation_count = 0;
    365     int64 parameter_allocation_bytes = 0;
    366     int64 maybe_live_out_allocation_count = 0;
    367     int64 maybe_live_out_allocation_bytes = 0;
    368     int64 preallocated_temp_allocation_count = 0;
    369     int64 preallocated_temp_allocation_bytes = 0;
    370     int64 preallocated_temp_fragmentation_bytes = -1;
    371     int64 total_allocation_count = 0;
    372     int64 total_allocation_bytes = 0;
    373     int64 total_fragmentation_bytes = -1;
    374 
    375     string ToString() const;
    376   };
    377   const Stats& GetStats() const { return stats_; }
    378 
    379  private:
    380   // Only BufferAssigner can build or modify BufferAssignments.
    381   friend class BufferAssigner;
    382 
    383   explicit BufferAssignment(const HloModule* module,
    384                             std::unique_ptr<BufferLiveness> liveness,
    385                             LogicalBuffer::SizeFunction buffer_size,
    386                             LogicalBuffer::AlignmentFunction color_alignment)
    387       : module_(module),
    388         liveness_(std::move(liveness)),
    389         buffer_size_(std::move(buffer_size)),
    390         color_alignment_(std::move(color_alignment)) {}
    391 
    392   // Creates and returns a new BufferAllocation, with no assigned
    393   // LogicalBuffers. Ownership is maintained internally.
    394   BufferAllocation* NewEmptyAllocation(int64 size, bool is_thread_local,
    395                                        bool is_reusable,
    396                                        LogicalBuffer::Color color);
    397 
    398   // Helper that calls NewEmptyAllocation and AddAssignment in one call,
    399   // creating an allocation containing a single LogicalBuffer.
    400   BufferAllocation* NewAllocation(const LogicalBuffer& buffer, int64 size,
    401                                   bool is_thread_local, bool is_reusable);
    402 
    403   // Adds a LogicalBuffer to the set assigned to the given allocation.
    404   void AddAssignment(BufferAllocation* allocation, const LogicalBuffer& buffer,
    405                      int64 offset, int64 size);
    406 
    407   // Returns the HloModule used to construct this assignment.
    408   const HloModule& module() const { return *module_; }
    409 
    410   // Convenience function which returns the PointsToSet for the given
    411   // instruction. Extracted from the liveness object.
    412   const PointsToSet& GetPointsToSet(const HloInstruction* instruction) const;
    413 
    414   // Mutable accessors for allocations.
    415   BufferAllocation* GetMutableAssignedAllocation(const LogicalBuffer& buffer);
    416   BufferAllocation* GetMutableAllocation(BufferAllocation::Index index);
    417 
    418   // Combines allocations of temporary buffers into one big BufferAllocation.
    419   void CombineTempAllocations();
    420 
    421   // Computes stats for the assignment, to be retrieved by GetStats.
    422   Status ComputeSummaryStats();
    423 
    424   // The vector of buffer allocations. Indexed by BufferAllocation::Index.
    425   std::vector<BufferAllocation> allocations_;
    426 
    427   // The total size of all temporary buffers.
    428   int64 temp_allocation_total_size_ = 0;
    429 
    430   // Maps Buffers to the index of the BufferAllocation which holds the buffer.
    431   tensorflow::gtl::FlatMap<const LogicalBuffer*, BufferAllocation::Index>
    432       allocation_index_for_buffer_;
    433 
    434   const HloModule* module_;
    435   const std::unique_ptr<BufferLiveness> liveness_;
    436 
    437   // Function which returns the buffer size for a given logical buffer (shape).
    438   LogicalBuffer::SizeFunction buffer_size_;
    439 
    440   // Function which returns the alignment for a given logical buffer color.
    441   LogicalBuffer::AlignmentFunction color_alignment_;
    442 
    443   Stats stats_;
    444   std::vector<HeapSimulatorTrace> heap_simulator_traces_;
    445 
    446   TF_DISALLOW_COPY_AND_ASSIGN(BufferAssignment);
    447 };
    448 
    449 // A class which constructs a buffer assignment.
    450 class BufferAssigner {
    451  public:
    452   // Build and return a BufferAssignment for the given module. The given
    453   // HloOrdering is used to determine buffer liveness. buffer_size and
    454   // color_alignment are functions which returns the size and alignment of a
    455   // LogicalBuffer.  allow_input_output_aliasing specifies whether input buffer
    456   // are allowed to be reused as outbut buffers by the client code.
    457   static StatusOr<std::unique_ptr<BufferAssignment>> Run(
    458       const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
    459       LogicalBuffer::SizeFunction buffer_size,
    460       LogicalBuffer::AlignmentFunction color_alignment,
    461       bool allow_input_output_aliasing = false,
    462       BufferLiveness::Colorer colorer = BufferLiveness::DefaultColorer());
    463 
    464  private:
    465   BufferAssigner(bool allow_input_output_aliasing,
    466                  BufferLiveness::Colorer colorer)
    467       : allow_input_output_aliasing_(allow_input_output_aliasing),
    468         colorer_(colorer) {}
    469   virtual ~BufferAssigner() = default;
    470 
    471   // Create a buffer assignment.
    472   StatusOr<std::unique_ptr<BufferAssignment>> CreateAssignment(
    473       const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
    474       LogicalBuffer::SizeFunction buffer_size,
    475       LogicalBuffer::AlignmentFunction color_alignment);
    476 
    477   // Assigns buffers to the instructions in the given computation. "assignment"
    478   // is modified to reflect the new buffer assignments. If is_thread_local is
    479   // true, then all assigned buffers have the is_thread_local flag set to
    480   // true.
    481   Status AssignBuffersForComputation(
    482       const HloComputation* computation, const DebugOptions& debug_options,
    483       bool is_thread_local,
    484       const tensorflow::gtl::FlatSet<const LogicalBuffer*>& colocated_buffers,
    485       const tensorflow::gtl::FlatSet<BufferAllocation::Index>&
    486           colocated_allocations,
    487       tensorflow::gtl::FlatMap<const HloComputation*,
    488                                tensorflow::gtl::FlatSet<const LogicalBuffer*>>*
    489           buffers_to_assign_sequentially,
    490       BufferAssignment* assignment);
    491 
    492   // Assigns 'buffers_to_assign_sequentially' using heap simulation, assuming
    493   // the HLO instructions will be executed in the sequential order given by
    494   // assignment->liveness().hlo_ordering().SequentialOrder. If
    495   // 'run_whole_module_heap_simulation' is true, the heap simulation will be run
    496   // assuming all global computations are sequentially ordered.
    497   Status AssignBuffersWithSequentialOrdering(
    498       const tensorflow::gtl::FlatMap<
    499           const HloComputation*,
    500           tensorflow::gtl::FlatSet<const LogicalBuffer*>>&
    501           buffers_to_assign_sequentially,
    502       bool run_whole_module_heap_simulation, BufferAssignment* assignment);
    503 
    504   // Uses the results of the heap simulator to create a single allocation, with
    505   // LogicalBuffers packed to specific offsets.
    506   void AssignBuffersFromHeapSimulator(const HeapSimulator::Result& result,
    507                                       BufferAssignment* assignment,
    508                                       LogicalBuffer::Color color);
    509 
    510   // Tries to assign the given instruction to the given buffer. Returns if the
    511   // assignment was successful.
    512   bool MaybeAssignBuffer(BufferAllocation* allocation,
    513                          const LogicalBuffer& buffer,
    514                          BufferAssignment* assignment);
    515 
    516   // Colocated buffers are logical buffers from different computations which
    517   // alias. Explicitly handling these colocated buffers is necessary because
    518   // points-to analysis is computation level scope and does not recognize
    519   // aliasing across computations (b/32491382).
    520   using ColocatedBufferSet = tensorflow::gtl::FlatSet<const LogicalBuffer*>;
    521 
    522   // Returns a vector of ColocatedBufferSet objects, where each
    523   // ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module'
    524   // which should be colocated in the same buffer allocation.
    525   void BuildColocatedBufferSets(
    526       const HloModule* module, const BufferLiveness& buffer_liveness,
    527       const LogicalBuffer::SizeFunction& buffer_size,
    528       std::vector<ColocatedBufferSet>* colocated_buffer_sets);
    529 
    530   // For each buffer set in 'colocated_buffer_sets', assigns all buffers in the
    531   // same set to the same buffer allocation in 'assignment'.
    532   void AssignColocatedBufferSets(
    533       const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
    534       BufferAssignment* assignment,
    535       tensorflow::gtl::FlatSet<const LogicalBuffer*>* colocated_buffers,
    536       tensorflow::gtl::FlatSet<BufferAllocation::Index>* colocated_allocations);
    537 
    538   // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining
    539   // the invariant that all sets in 'colocated_buffer_sets' are disjoint.
    540   void AddSetToColocatedBufferSets(
    541       const std::vector<const LogicalBuffer*>& colocated_set,
    542       std::vector<ColocatedBufferSet>* colocated_buffer_sets);
    543 
    544   // Given a list of colocated buffer sets (each colocated buffer set represents
    545   // the logical buffers that would be assigned to the same physical buffer),
    546   // try to merge the sets if the buffers can be shared. Returns the merged set.
    547   std::vector<ColocatedBufferSet> MergeColocatedBufferSets(
    548       const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
    549       const BufferLiveness& buffer_liveness,
    550       const LogicalBuffer::SizeFunction& buffer_size);
    551 
    552   // Split a set of buffers into several sets, each of which contains buffers
    553   // colored with the same color.
    554   tensorflow::gtl::FlatMap<LogicalBuffer::Color,
    555                            tensorflow::gtl::FlatSet<const LogicalBuffer*>,
    556                            LogicalBuffer::Color::Hasher>
    557   SplitBuffersByColor(
    558       const tensorflow::gtl::FlatSet<const LogicalBuffer*>& buffers);
    559 
    560   // If true, buffer assignments assumes that input parameter buffers and output
    561   // buffers can be shared if their sizes match.
    562   bool allow_input_output_aliasing_;
    563 
    564   // Functor used to assign colors to newly allocated logical buffers.
    565   BufferLiveness::Colorer colorer_;
    566 
    567   TF_DISALLOW_COPY_AND_ASSIGN(BufferAssigner);
    568 };
    569 
    570 }  // namespace xla
    571 
    572 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
    573