Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_
     18 
     19 #include <stddef.h>
     20 #include <iosfwd>
     21 #include <memory>
     22 #include <set>
     23 #include <string>
     24 #include <vector>
     25 
     26 #include "absl/container/flat_hash_map.h"
     27 #include "absl/container/inlined_vector.h"
     28 #include "absl/types/span.h"
     29 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
     30 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     31 #include "tensorflow/compiler/xla/service/hlo_module.h"
     32 #include "tensorflow/compiler/xla/service/logical_buffer.h"
     33 #include "tensorflow/compiler/xla/service/logical_buffer_analysis.h"
     34 #include "tensorflow/compiler/xla/shape_tree.h"
     35 #include "tensorflow/compiler/xla/statusor.h"
     36 #include "tensorflow/compiler/xla/types.h"
     37 #include "tensorflow/compiler/xla/xla_data.pb.h"
     38 #include "tensorflow/core/lib/core/status.h"
     39 #include "tensorflow/core/lib/gtl/compactptrset.h"
     40 #include "tensorflow/core/platform/macros.h"
     41 #include "tensorflow/core/platform/types.h"
     42 
     43 namespace xla {
     44 
     45 // A class describing the source(s) of the Buffer(s) contained in the output of
     46 // a particular HLO instruction. The structure of PointsToSet mirrors the
     47 // structure of the instruction's shape, which may be an arbitrary tree (eg, a
     48 // nested tuple). Each node in this tree corresponds to a single buffer in the
     49 // instruction's output and contains the set of Buffers which might define
     50 // the corresponding buffer.
     51 class PointsToSet {
     52  public:
     53   // Construct our ShapeTree with a pointer rather than a reference to a Shape
     54   // because this is very hot code, and copying (and then destroying) all these
     55   // Shapes is slow.
     56   explicit PointsToSet(const Shape* shape) : tree_(shape) {}
     57 
     58   // Returns true if any points-to sets for any subshape element is not a
     59   // singleton.
     60   bool IsAmbiguous() const;
     61 
     62   // Returns true if no LogicalBuffer appears in more than one points-to set of
     63   // the shape nodes.
     64   bool IsDistinct() const;
     65 
     66   // Returns the total number of different LogicalBuffers contained in this
     67   // object. This is equal to CreateFlattenedSet().size().
     68   size_t size() const;
     69 
     70   // Creates a set containing the union of all LogicalBuffers contained in the
     71   // PointsToSet.
     72   using BufferSet = tensorflow::gtl::CompactPointerSet<const LogicalBuffer*>;
     73   BufferSet CreateFlattenedSet() const;
     74 
     75   // Returns true if the given buffer is in the points-to set at the given
     76   // index.
     77   bool ContainsBufferAtIndex(const LogicalBuffer& buffer,
     78                              const ShapeIndex& index) const;
     79 
     80   // Returns true if the given buffer is in the points-to set at any index.
     81   bool ContainsBuffer(const LogicalBuffer& buffer) const;
     82 
     83   // Adds the given buffer to the points-to set at the given index. This is a
     84   // nop if the buffer already is in the set at that index.
     85   void AddPointedToBuffer(const LogicalBuffer& buffer, const ShapeIndex& index);
     86 
     87   // For the subshape at the given index (where index is defined as in
     88   // ShapeUtil::GetSubshape) this method returns the set of HLO instructions
     89   // which may produce the tuple subshape at that index. For example, given:
     90   //
     91   // %tuple1 = tuple(...)
     92   // %tuple2 = tuple(...)
     93   // %select = select(%tuple1, %tuple2)
     94   // %nested_tuple = tuple(%select, %tuple1)
     95   //
     96   // These are the values for tuple_sources() for the PointsToSet of
     97   // %nested_tuple:
     98   //
     99   // tuple_sources({}) = {%nested_tuple}
    100   // tuple_sources({0}) = {%tuple1, %tuple2}
    101   // tuple_sources({1}) = {%tuple1}
    102   //
    103   // tuple_sources() at the index of an array shape (not a tuple) returns the
    104   // empty set. The instructions in the set returned by tuple_sources
    105   // necessarily are either Tuple instructions, constants, or parameters.
    106   using SourceSet = tensorflow::gtl::CompactPointerSet<HloInstruction*>;
    107   const SourceSet& tuple_sources(const ShapeIndex& index) const;
    108 
    109   // Add a tuple source instruction for the given index.
    110   void add_tuple_source(const ShapeIndex& index, HloInstruction* tuple);
    111 
    112   using BufferList = absl::InlinedVector<const LogicalBuffer*, 1>;
    113 
    114   // Return the list of logical buffers for the subshape at index.
    115   const BufferList& element(const ShapeIndex& index) const {
    116     return tree_.element(index).buffers;
    117   }
    118   BufferList* mutable_element(const ShapeIndex& index) {
    119     return &tree_.mutable_element(index)->buffers;
    120   }
    121 
    122   // Call fn(index, buflist) for every subshape index.
    123   template <typename Fn>
    124   void ForEachElement(const Fn& fn) const {
    125     tree_.ForEachElement([&fn](const ShapeIndex& index, const Elem& elem) {
    126       fn(index, elem.buffers);
    127     });
    128   }
    129   template <typename Fn>
    130   void ForEachMutableElement(const Fn& fn) {
    131     tree_.ForEachMutableElement([&fn](const ShapeIndex& index, Elem* elem) {
    132       fn(index, &elem->buffers);
    133     });
    134   }
    135   template <typename Fn>
    136   Status ForEachElementWithStatus(const Fn& fn) const {
    137     return tree_.ForEachElementWithStatus(
    138         [&fn](const ShapeIndex& index, const Elem& elem) {
    139           return fn(index, elem.buffers);
    140         });
    141   }
    142 
    143  private:
    144   struct Elem {
    145     BufferList buffers;
    146     SourceSet tuple_sources;
    147   };
    148   ShapeTree<Elem> tree_;
    149 
    150   // PointsToSet contains references (const LogicalBuffer*) to elements within
    151   // TuplePointsToAnalysis, so disable copying.
    152   TF_DISALLOW_COPY_AND_ASSIGN(PointsToSet);
    153 };
    154 
    155 // This class describes a particular subshape in a computation (instruction and
    156 // shape index) and the logical buffer which may be a source of the subshape
    157 // value.
    158 class BufferAlias {
    159  public:
    160   BufferAlias(HloInstruction* instruction, const ShapeIndex& index)
    161       : instruction_(instruction), index_(index) {}
    162 
    163   // Return the instruction/index of the subshape.
    164   HloInstruction* instruction() const { return instruction_; }
    165   const ShapeIndex& index() const { return index_; }
    166 
    167   bool operator==(const BufferAlias& other) const {
    168     return instruction_ == other.instruction_ && index_ == other.index_;
    169   }
    170   bool operator!=(const BufferAlias& other) const { return !(*this == other); }
    171 
    172   string ToString() const;
    173 
    174  private:
    175   HloInstruction* instruction_;
    176   ShapeIndex index_;
    177 };
    178 
    179 std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias);
    180 
    181 // DFS visitor that performs tuple points-to analysis. This analysis determines
    182 // the potential sources of each buffer in each instruction's output.
    183 class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
    184  public:
    185   // Runs points-to analysis on 'module'.
    186   static StatusOr<std::unique_ptr<TuplePointsToAnalysis>> Run(
    187       const HloModule* module);
    188 
    189   // Return the points-to set of an instruction. This describes the potential
    190   // sources of each buffer in the instruction's output.
    191   const PointsToSet& GetPointsToSet(
    192       const HloInstruction* hlo_instruction) const;
    193 
    194   // Returns the logical buffer with the given ID.
    195   const LogicalBuffer& GetBuffer(LogicalBuffer::Id id) const;
    196 
    197   // Returns the buffer defined at the given instruction and index. An error is
    198   // returned if no buffer is defined at that point.
    199   StatusOr<const LogicalBuffer*> GetBufferDefinedAt(
    200       const HloInstruction* instruction, const ShapeIndex& index) const;
    201 
    202   // Return a (possibly empty) vector containing all BufferAliases of the given
    203   // logical buffer The buffer alias set is the inverse of the points-to set.
    204   // That is, LogicalBuffer B is in the points-to set of instruction I at index
    205   // N iff instruction I, index N is a BufferAlias of B.
    206   using BufferAliasVector = absl::InlinedVector<BufferAlias, 1>;
    207   const BufferAliasVector& GetBufferAliases(const LogicalBuffer& buffer) const;
    208 
    209   // Returns the number of logical buffers in the module
    210   LogicalBuffer::Id num_logical_buffers() const {
    211     return logical_buffer_analysis_->num_logical_buffers();
    212   }
    213 
    214   // Return a the logical buffer with id "id" in the module. Iteration
    215   // over all logical buffers is usually done with something like:
    216   //
    217   // for (LogicalBuffer:Id id = 0; id < points_to.num_logical_buffers(); id++){
    218   //   const auto& buffer = points_to.logical_buffer(id);
    219   //   ... do something with buffer ...
    220   // }
    221   LogicalBuffer& logical_buffer(LogicalBuffer::Id id) const {
    222     return logical_buffer_analysis_->GetBuffer(id);
    223   }
    224 
    225   // Returns a vector of buffers that the instruction produces. Most
    226   // instructions produce a single buffer (the top-level buffer), some produce
    227   // no buffers (eg bitcast), and some produce more than one buffer (eg,
    228   // tuple-shaped parameters).
    229   using BufferDefinitionVector = absl::InlinedVector<const LogicalBuffer*, 1>;
    230   const BufferDefinitionVector& GetBuffersDefinedByInstruction(
    231       const HloInstruction* instruction) const;
    232 
    233   // Returns true if the given instruction defines a buffer at the given index.
    234   bool InstructionDefinesBufferAtIndex(const HloInstruction* instruction,
    235                                        const ShapeIndex& index) const;
    236 
    237   // Returns an OK status if the given buffer is defined by instruction
    238   // 'buffer.instruction()' at index 'buffer.index()' and if the given buffer
    239   // matches the TuplePointsToAnalysis' LogicalBuffer with 'buffer.id'. Returns
    240   // an FailedPrecondition error status otherwise. An example of a LogicalBuffer
    241   // which is not defined is a tuple element in a Tuple instruction. In this
    242   // case, the Tuple instruction does not define the LogicalBuffer, rather that
    243   // index aliases one of its operands.
    244   Status VerifyBuffer(const LogicalBuffer& buffer) const;
    245 
    246   Status DefaultAction(HloInstruction* hlo_instruction) override;
    247   Status HandleTuple(HloInstruction* tuple) override;
    248   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
    249   Status HandleBitcast(HloInstruction* bitcast) override;
    250   Status HandleDomain(HloInstruction* domain) override;
    251   Status HandleCopy(HloInstruction* copy) override;
    252   Status HandleRecvDone(HloInstruction* recv_done) override;
    253   Status HandleSend(HloInstruction* send) override;
    254   Status HandleTupleSelect(HloInstruction* tuple_select) override;
    255   Status HandleAddDependency(HloInstruction* add_dependency) override;
    256 
    257   string ToString() const;
    258 
    259   // Returns true if 'user' cannot possibly use the buffer at 'index' in
    260   // 'operand'. Returns false otherwise.
    261   //
    262   // REQUIRES: 'operand' is an operand of 'user'.
    263   bool DoesNotUseOperandBuffer(const HloInstruction* operand,
    264                                const ShapeIndex& index,
    265                                const HloInstruction* user) const;
    266 
    267   // Returns true if 'user' (at 'user_index') can share a buffer with its
    268   // operand 'operand' (at 'operand_index'). Returns false otherwise.
    269   //
    270   // REQUIRES: 'operand' is an operand of 'user'.
    271   bool CanShareOperandBufferWithUser(HloInstruction* operand,
    272                                      const ShapeIndex& operand_index,
    273                                      HloInstruction* user,
    274                                      const ShapeIndex& user_index) const;
    275 
    276  private:
    277   explicit TuplePointsToAnalysis(
    278       const HloModule* module,
    279       std::unique_ptr<LogicalBufferAnalysis> logical_buffer_analysis)
    280       : module_(module),
    281         logical_buffer_analysis_(std::move(logical_buffer_analysis)) {}
    282 
    283   // Perform the analysis. Should be called immediately after constructing the
    284   // object and before calling GetPointsToSet.
    285   Status Analyze();
    286 
    287   // Populates instruction-defined buffers and aliases for each instruction
    288   // in 'instructions'.
    289   Status PopulateDefinedBuffersAndAliases(const decltype(
    290       std::declval<HloComputation>().instructions())& instructions);
    291 
    292   // Creates an empty PointsToSet in the points_to_ map for the given
    293   // instruction.
    294   PointsToSet& CreateEmptyPointsToSet(const HloInstruction* instruction);
    295 
    296   // Creates a PointsToSet in the points_to_ map for 'instruction' which is a
    297   // copy of the existing PointsToSet for 'src'.
    298   PointsToSet& CreateCopiedPointsToSet(const HloInstruction* instruction,
    299                                        const HloInstruction* src);
    300 
    301   // Adds the buffers defined by the given instruction to the given vector.
    302   Status GatherBuffersDefinedByInstruction(const HloInstruction* instruction,
    303                                            BufferDefinitionVector* buffers);
    304 
    305   // Print points-to set for 'instruction' to 'output'.
    306   void InstructionToString(const HloInstruction* instruction,
    307                            string* output) const;
    308 
    309   // Information kept per instruction
    310   struct PerInstruction {
    311     std::unique_ptr<PointsToSet> points_to_set;
    312     // Empircally, ~92% of instructions have 1
    313     // instruction_defined_buffer, and 99% have 0 or 1
    314     BufferDefinitionVector instruction_defined_buffers;
    315   };
    316 
    317   const PerInstruction* PerInst(const HloInstruction* inst) const {
    318     int id = inst->unique_id();
    319     DCHECK_GE(id, 0);
    320     auto iter = per_instruction_.find(id);
    321     if (iter == per_instruction_.end()) {
    322       LOG(FATAL) << "Expected per-instruction information to already exist";
    323     } else {
    324       return iter->second.get();
    325     }
    326   }
    327   PerInstruction* PerInst(const HloInstruction* inst) {
    328     int id = inst->unique_id();
    329     DCHECK_GE(id, 0);
    330     auto iter = per_instruction_.find(id);
    331     if (iter == per_instruction_.end()) {
    332       return per_instruction_.emplace(id, absl::make_unique<PerInstruction>())
    333           .first->second.get();
    334     } else {
    335       return iter->second.get();
    336     }
    337   }
    338 
    339   std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
    340       HloInstruction* instruction, const ShapeIndex& index) const;
    341   bool HasUniqueFusedUseOfOperandAt(HloInstruction* operand,
    342                                     const ShapeIndex& operand_index,
    343                                     HloInstruction* fusion,
    344                                     const int64 use_operand_index) const;
    345 
    346   // The module this analysis is performed on.
    347   const HloModule* module_;
    348 
    349   // The logical buffers for this module.
    350   const std::unique_ptr<LogicalBufferAnalysis> logical_buffer_analysis_;
    351 
    352   // A map from instruction->unique_id() to
    353   absl::flat_hash_map<int, std::unique_ptr<PerInstruction>> per_instruction_;
    354 
    355   // A map from LogicalBuffer->id() to alias information about that logical
    356   // buffer
    357   std::vector<BufferAliasVector> logical_buffer_aliases_;
    358 
    359   TF_DISALLOW_COPY_AND_ASSIGN(TuplePointsToAnalysis);
    360 };
    361 
    362 }  // namespace xla
    363 
    364 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_
    365