// Home | History | Annotate | Download | only in service
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_
     18 
     19 #include <type_traits>
     20 
     21 #include "absl/container/flat_hash_map.h"
     22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     23 #include "tensorflow/compiler/xla/service/hlo_module.h"
     24 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
     25 #include "tensorflow/core/util/ptr_util.h"
     26 
     27 namespace xla {
     28 
     29 // IndexedArrayAnalysis decides if an HLO instruction can be rewritten as a
     30 // gather from another array.  It does this by mapping HLO instructions to
     31 // instances of IndexedArrayAnalysis::Array, which can be inspected to discover
     32 // whether said HLO is equivalent to a gather.
     33 class IndexedArrayAnalysis {
     34  public:
     35   // IndexedArrayAnalysis maps each HLO instruction to an instance of a Array.
     36   // Array really just a sum type of the classes that inherit from it.  The
     37   // meaning of each of the subtypes is documented on the subtype declaration.
     38   //
     39   // Array instances are immutable once created.
     40   class Array {
     41    public:
     42     enum Kind {
     43       kUnknown,
     44       kConstant,
     45       kReshaped,
     46       kScalarIndexedConstant,
     47       kScalarIndexed
     48     };
     49 
     50     virtual Kind kind() const = 0;
     51     virtual const Shape& shape() const = 0;
     52 
     53     // Does a checked downcast from `Array` to `T` which must be one of its
     54     // subtypes.
     55     template <typename T>
     56     T* as() {
     57       static_assert((std::is_base_of<Array, T>::value),
     58                     "target type not derived from source type");
     59       // We skip the CHECK and hence the dynamic_cast if RTTI is disabled.
     60 #if !defined(__GNUC__) || defined(__GXX_RTTI)
     61       CHECK_NE(dynamic_cast<T*>(this), nullptr);
     62 #endif  // !defined(__GNUC__) || defined(__GXX_RTTI)
     63 
     64       return static_cast<T*>(this);
     65     }
     66 
     67     virtual ~Array() = default;
     68 
     69     Array& operator=(const Array& other) = delete;
     70   };
     71 
     72   // Represents an HLO instruction that was not analyzable by this
     73   // IndexedArrayAnalysis.  Instances of UnknownArray just wrap an existing
     74   // HloInstruction.
     75   class UnknownArray : public Array {
     76    public:
     77     Kind kind() const override { return kUnknown; }
     78     const Shape& shape() const override { return instruction().shape(); }
     79     const HloInstruction& instruction() const { return instruction_; }
     80 
     81    private:
     82     explicit UnknownArray(const HloInstruction* instr) : instruction_(*instr) {}
     83 
     84     const HloInstruction& instruction_;
     85 
     86     friend class IndexedArrayAnalysis;
     87   };
     88 
     89   // Represents a constant value.  This constant value may be present in the HLO
     90   // module being analyzed, or it could have been created on the fly by the
     91   // analysis.
     92   class ConstantArray : public Array {
     93    public:
     94     Kind kind() const override { return kConstant; }
     95     const Shape& shape() const override { return literal()->shape(); }
     96     const Literal* literal() const { return literal_; }
     97 
     98    private:
     99     explicit ConstantArray(const Literal* literal) : literal_(literal) {}
    100     const Literal* literal_;
    101 
    102     friend class IndexedArrayAnalysis;
    103   };
    104 
    105   // Represents an Array that is a reshape of another Array.
    106   class ReshapedArray : public Array {
    107    public:
    108     Kind kind() const override { return kReshaped; }
    109 
    110     // The array to reshape.
    111     Array* operand() const { return operand_; }
    112 
    113     // The output shape.
    114     const Shape& shape() const override { return shape_; }
    115 
    116    private:
    117     explicit ReshapedArray(Array* operand, Shape shape)
    118         : operand_(operand), shape_(shape) {}
    119 
    120     Array* operand_;
    121     const Shape shape_;
    122 
    123     friend class IndexedArrayAnalysis;
    124   };
    125 
    126   // ---------------------------------------------------------------------------
    127   // Indexed Array Overview
    128   // ---------------------------------------------------------------------------
    129   //
    130   // ScalarIndexedArray and ScalarIndexedConstantArray form the core of this
    131   // analysis.  ScalarIndexedConstantArray is just a specialization of
    132   // ScalarIndexedArray so we will only discuss ScalarIndexedArray in this
    133   // overview.
    134   //
    135   // A ScalarIndexedArray represents an array that can be computed by indexing
    136   // into a "source" array using an "indices" tensor.  A simple example is a
    137   // gather operation gathering 12 rows out of a [100,100] matrix -- such an
    138   // operation will be represented by an instance of a ScalarIndexedArray with
    139   // the [100,100] matrix as the "source" array and the [12]-shaped indices
    140   // array as the "indices" tensor.  The ScalarIndexedArray operation itself
    141   // will be of shape [12,100] (assuming we were gathering with axis=0).
    142   //
    143   // Gather operations are not the only operation that maps to
    144   // ScalarIndexedArray instances (if that were true there would be little point
    145   // in having a separate analysis).  We can often infer ScalarIndexedArrays for
    146   // other operations too.  For instance, consider:
    147   //
    148   //   %source = f32[100,100] constant
    149   //   %indices = s32[12] ...
    150   //   %gather = f32[12,100] ... gather from %source using %indices at axis 0
    151   //   %dot = dot(%gather, other_constant) [canonical contracting dims]
    152   //
    153   // The dot operation itself is also a ScalarIndexedArray with source =
    154   // dot(constant, other_constant) and indices = %indices.  A reshape of %gather
    155   // to [12,5,20] too is a ScalarIndexedArray with source = an appropriately
    156   // reshaped constant and indices = %indices.
    157 
    158   // Represents the result of a gather operation.  This gather operation may
    159   // explicitly be present in the HLO module being analyzed, or it could have
    160   // been created on the fly by the analysis.
    161   //
    162   // An instance of ScalarIndexedArray represents a array whose I'th element can
    163   // be mapped to the J'th element of the `source` array (where I and J are
    164   // multidimensional indices) in this way:
    165   //
    166   //   I' = remove components at positions `output_dims` from I
    167   //   G' = remove components not at positions `output_dims` from I
    168   //   T  = indices[G']
    169   //   J  = I' with T inserted at position `source_dim`
    170   //
    171   // For example, if source is of shape [11,13,17,19], indices is of shape
    172   // [23,29], output_dims is [0,2] and source_dim is 2 then the output is of
    173   // shape [23,11,29,13,19] and the output index [A,B,C,D,E] is mapped to the
    174   // input index [B,D,indices[A,C],E].
    175   class ScalarIndexedArray : public Array {
    176    public:
    177     Kind kind() const override { return kScalarIndexed; }
    178     const Shape& shape() const override { return shape_; }
    179 
    180     Array* source() const { return source_; }
    181     Array* indices() const { return indices_; }
    182 
    183     // `source_dim` is the dimension in the source array that is being indexed
    184     // over using indices from the `indices` array.  See the class documentation
    185     // and the overview for more details.
    186     int64 source_dim() const { return source_dim_; }
    187 
    188     // `output_dims` are the dimensions in the output array that are being used
    189     // to compute an index into the `indices` array.  See the class
    190     // documentation and the overview for more details.
    191     absl::Span<const int64> output_dims() const { return output_dims_; }
    192 
    193    private:
    194     explicit ScalarIndexedArray(Array* source, Array* indices, int64 source_dim,
    195                                 std::vector<int64> output_dims, Shape shape)
    196         : source_(source),
    197           indices_(indices),
    198           source_dim_(source_dim),
    199           output_dims_(std::move(output_dims)),
    200           shape_(std::move(shape)) {}
    201 
    202     Array* source_;
    203     Array* indices_;
    204     int64 source_dim_;
    205     std::vector<int64> output_dims_;
    206     Shape shape_;
    207 
    208     friend class IndexedArrayAnalysis;
    209   };
    210 
    211   // A ScalarIndexedConstantArray is just a ScalarIndexedArray constrained to
    212   // have a ConstantArray instance as the source.  This is an ergonomic
    213   // concession -- in theory it is possible to just keep ScalarIndexedArray and
    214   // check source()->kind().
    215   class ScalarIndexedConstantArray : public ScalarIndexedArray {
    216    public:
    217     Kind kind() const override { return kScalarIndexedConstant; }
    218 
    219     const Literal& literal() const {
    220       return *source()->as<ConstantArray>()->literal();
    221     }
    222 
    223    private:
    224     explicit ScalarIndexedConstantArray(Array* source, Array* indices,
    225                                         int64 source_dim,
    226                                         std::vector<int64> output_dims,
    227                                         Shape shape)
    228         : ScalarIndexedArray(source, indices, source_dim,
    229                              std::move(output_dims), std::move(shape)) {
    230       CHECK(dynamic_cast<ConstantArray*>(source));
    231     }
    232 
    233     friend class IndexedArrayAnalysis;
    234   };
    235 
    236   // Returns an Array instance for `instr`.  The IndexedArrayAnalysis instance
    237   // keeps ownership of the returned Array instance.
    238   //
    239   // Caching Behavior: IndexedArrayAnalysis has a cache mapping HLO
    240   // instructions to IndexedArrayAnalysis::Array instances.  This entire cache
    241   // becomes stale and may cause the analysis to return incorrect results if any
    242   // transitive operand (stopping at the containing computation) is modified for
    243   // any HLO instruction on which GetArrayFor has been invoked.
    244   //
    245   // NB!  By inspecting the implementation, you may be able to infer a stronger
    246   // caching guarantee than what is mentioned above.  Nevertheless, what is
    247   // stated above is the contract.
    248   StatusOr<Array*> GetArrayFor(const HloInstruction* instr);
    249 
    250   // Pretty-prints the expression rooted at `root`.
    251   string ToString(Array* root, bool print_constants = false);
    252 
    253  private:
    254   // Helper function that ensures that every HLO instruction that is
    255   // transitively used by `root` has an entry in `cache_`.
    256   Status TraverseAndPopulateCache(const HloInstruction* root);
    257 
    258   // Creates an Array instance for `instr` under the assumption that all
    259   // operations of `instr` are present in `cache_`.
    260   StatusOr<Array*> ComputeArrayFor(const HloInstruction* instr);
    261 
    262   StatusOr<Array*> ComputeArrayForConstant(const Literal& literal);
    263 
    264   StatusOr<Array*> ComputeArrayForGather(
    265       const Shape& shape, const GatherDimensionNumbers& dim_numbers,
    266       absl::Span<const int64> slice_sizes, Array* source, Array* indices);
    267 
    268   StatusOr<Array*> ComputeArrayForDotWithIndexedLhs(
    269       const Shape& shape, const DotDimensionNumbers& dim_numbers,
    270       const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs,
    271       ConstantArray* rhs);
    272 
    273   StatusOr<Array*> ComputeArrayForDotWithIndexedRhs(
    274       const Shape& shape, const DotDimensionNumbers& dim_numbers,
    275       const PrecisionConfig& precision_config, ConstantArray* lhs,
    276       ScalarIndexedConstantArray* rhs);
    277 
    278   StatusOr<Array*> ComputeArrayForDot(const Shape& shape,
    279                                       const DotDimensionNumbers& dim_numbers,
    280                                       const PrecisionConfig& precision_config,
    281                                       Array* lhs, Array* rhs);
    282 
    283   // This tries to fold a ScalarIndexedArray which has another
    284   // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a
    285   // ScalarIndexedArray as indices.  If `source` happened to be a
    286   // ScalarIndexedConstantArray this can result in an expression that is more
    287   // canonical.
    288   //
    289   // As an example, consider a gather operation, G0, gathering 7 elements from
    290   // an array "Arr" of shape [100] resulting in an array of shape [7], and a
    291   // second gather operation, G1, which gathers 3 elements out of the result of
    292   // G0 resulting in an array of shape [3].  Let the indices uses by G0 be I0
    293   // (of shape [7]) and the indices used by G1 be I1 (of shape [3]).  We can
    294   // instead rewrite G1 to gather directly from "Arr" with the three indices
    295   // from I0 as per I1.  In other words, we can rewrite:
    296   //
    297   //    G0 = [Arr[i] for i in I0]
    298   //    G1 = [G0[i]  for i in I1]
    299   //
    300   // into
    301   //
    302   //    I2 = [I0[i]  for i in I1]
    303   //    G1 = [Arr[i] for i in I2]
    304   StatusOr<ScalarIndexedArray*> FoldGatherOfGather(
    305       ScalarIndexedArray* source, Array* indices, int64 source_dim,
    306       absl::Span<const int64> output_dims, Shape shape);
    307 
    308   // Reshapes a scalar-indexed node to remove the degenerate dimensions in its
    309   // output.  The result is always a scalar-indexed node.
    310   StatusOr<ScalarIndexedArray*> ReshapeToRemoveDegenerateDims(
    311       ScalarIndexedArray* operand);
    312 
    313   // Reshapes a scalar-indexed node such that the result has the degenerate
    314   // dimensions `degenerate_dims`.  The result is always a scalar-indexed node.
    315   StatusOr<ScalarIndexedArray*> ReshapeToAddDegenerateDims(
    316       ScalarIndexedArray* operand, absl::Span<const int64> degenerate_dims);
    317 
    318   StatusOr<ScalarIndexedArray*> FoldReshapeOfGather(
    319       const Shape& shape, ScalarIndexedConstantArray* operand);
    320   StatusOr<ScalarIndexedArray*> FoldReshapeOfGatherNoDegenerateDims(
    321       const Shape& shape, ScalarIndexedConstantArray* scalar_indexed);
    322   StatusOr<Array*> ComputeArrayForReshape(const Shape& shape, Array* operand);
    323 
    324   StatusOr<Array*> ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
    325                                                       Array* lhs, Array* rhs);
    326   StatusOr<Array*> ComputeArrayForElementwiseUnaryOp(HloOpcode opcode,
    327                                                      Array* operand);
    328 
    329   template <typename T, typename... Args>
    330   T* Construct(Args&&... args) {
    331     T* new_tensor = new T(std::forward<Args>(args)...);
    332     owned_tensors_.push_back(std::unique_ptr<T>(new_tensor));
    333     return new_tensor;
    334   }
    335 
    336   ScalarIndexedArray* ConstructScalarIndexedArray(
    337       Array* source, Array* indices, int64 source_dim,
    338       std::vector<int64> output_dims, Shape shape) {
    339     if (source->kind() == Array::kConstant) {
    340       return Construct<ScalarIndexedConstantArray>(source, indices, source_dim,
    341                                                    std::move(output_dims),
    342                                                    std::move(shape));
    343     } else {
    344       return Construct<ScalarIndexedArray>(source, indices, source_dim,
    345                                            std::move(output_dims),
    346                                            std::move(shape));
    347     }
    348   }
    349 
    350   Literal* TakeOwnership(Literal literal) {
    351     owned_literals_.push_back(std::move(literal));
    352     return &owned_literals_.back();
    353   }
    354 
    355   StatusOr<Literal*> TakeOwnership(StatusOr<Literal> literal_or_error) {
    356     TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error));
    357     owned_literals_.push_back(std::move(literal));
    358     return &owned_literals_.back();
    359   }
    360 
    361   std::vector<std::unique_ptr<Array>> owned_tensors_;
    362   std::vector<Literal> owned_literals_;
    363   absl::flat_hash_map<const HloInstruction*, Array*> cache_;
    364 };
    365 
// A pass that prints all non-trivial results returned by IndexedArrayAnalysis.
// This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to
// unconditionally add to the regular HLO pass pipeline.
class IndexedArrayAnalysisPrinterPass : public HloModulePass {
 public:
  // Name under which this pass appears in the HLO pass pipeline.
  absl::string_view name() const override;
  // Runs the analysis over `module` and VLOGs non-trivial results.
  // NOTE(review): as a printing-only pass it presumably never mutates the
  // module (i.e. returns false on success) -- confirm against the .cc.
  StatusOr<bool> Run(HloModule* module) override;
};
    374 
    375 }  // namespace xla
    376 
    377 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_
    378