Home | History | Annotate | Download | only in llvm_ir
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_
     18 
     19 #include <map>
     20 #include <vector>
     21 
     22 #include "absl/algorithm/container.h"
     23 #include "absl/strings/string_view.h"
     24 #include "absl/types/span.h"
     25 #include "llvm/IR/IRBuilder.h"
     26 #include "llvm/IR/Value.h"
     27 #include "tensorflow/compiler/xla/map_util.h"
     28 #include "tensorflow/compiler/xla/shape.h"
     29 #include "tensorflow/compiler/xla/types.h"
     30 #include "tensorflow/compiler/xla/xla_data.pb.h"
     31 #include "tensorflow/core/platform/logging.h"
     32 #include "tensorflow/core/platform/types.h"
     33 
     34 namespace xla {
     35 namespace llvm_ir {
     36 
     37 // IrArray represents an XLA array at the LLVM IR level. This class
     38 // encapsulates a base pointer to the buffer holding the array (as an LLVM
     39 // Value) and the shape of the array. The class includes methods for emitting
     40 // LLVM IR sequences which access elements of the array at a multidimensional
     41 // index (eg, [x, y, z] in a 3-dimensional array). Arbitrary shape and layouts
     42 // are supported.
     43 class IrArray {
     44  public:
     45   // A multidimensional index into an IrArray. The index for dimension zero is
     46   // first in the vector. This is the reverse order of the notation used for
     47   // describing the dimensions of an array. That is, for a [4 x 3 x 2] array
     48   // dimension zero has size 2, dimension one has size 3, and dimension two has
     49   // size 4. Thus the index {1, 2, 3} indexes the last element of this [4 x 3 x
     50   // 2] array.
     51   //
     52   // This may also keep a linear index and the layout and dimensions it was
     53   // emitted for; if the shape where this `Index` is used matches, the linear
     54   // index may be used, potentially sparing the cost of computing the
     55   // multidimensional index, which LLVM DCE can delete.
     56   class Index {
     57    public:
     58     // Constructs an index for a scalar shape.
     59     explicit Index(llvm::Type* index_ty) : index_type_(index_ty) {
     60       CHECK(index_ty->isIntegerTy());
     61     }
     62 
     63     // Constructs an index from multi-dimensional index "multidim". The linear
     64     // index is set to nullptr.
     65     explicit Index(absl::Span<llvm::Value* const> multidim,
     66                    llvm::Type* index_ty = nullptr)
     67         : multidim_(multidim.begin(), multidim.end()) {
     68       if (size() == 0) {
     69         index_type_ = index_ty;
     70       } else {
     71         for (const auto* dim : multidim) {
     72           CHECK_NE(dim, nullptr);
     73         }
     74         index_type_ = multidim[0]->getType();
     75         if (index_ty != nullptr) {
     76           CHECK_EQ(index_type_, index_ty);
     77         }
     78       }
     79       CHECK_NE(index_type_, nullptr);
     80       CHECK(absl::c_all_of(multidim, [&](llvm::Value* v) {
     81         return index_type_ == v->getType();
     82       }));
     83     }
     84 
     85     // Constructs an index from linear index "linear" and computes the
     86     // multi-dimensional index from "linear" and "shape". "b" is the IR
     87     // builder to emit the index of each dimension in the multi-dimensional
     88     // index.
     89     //
     90     // Precondition: "shape" has a layout.
     91     Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b);
     92 
     93     // Constructs an index from a multi-dimensional index. 'shape' is the shape
     94     // for which the multi-dimensional index is used. 'index_type' is the type
     95     // of the index.
     96     //
     97     // Precondition: "shape" has a layout.
     98     Index(absl::Span<llvm::Value* const> multidim, const Shape& shape,
     99           llvm::Type* index_type);
    100 
    101     // Returns an index that adds `addend` to the given `dim` of the object.
    102     Index AddOffsetToDim(llvm::Value* addend, int64 dim,
    103                          llvm::IRBuilder<>* b) const {
    104       std::vector<llvm::Value*> multi_index = multidim();
    105       multi_index[dim] = b->CreateAdd(multi_index[dim], addend);
    106       return Index(multi_index, index_type_);
    107     }
    108 
    109     const std::vector<llvm::Value*>& multidim() const { return multidim_; }
    110     llvm::Value* linear() const { return linear_; }
    111 
    112     size_t size() const { return multidim().size(); }
    113 
    114     llvm::Value* operator[](size_t i) const { return multidim()[i]; }
    115 
    116     using const_iterator = std::vector<llvm::Value*>::const_iterator;
    117 
    118     const_iterator begin() const { return multidim().begin(); }
    119     const_iterator end() const { return multidim().end(); }
    120 
    121     bool LinearValidOnShape(const Shape& a) const;
    122 
    123     // Given that "this" is the target index of a reshape from `operand_shape`
    124     // to `shape`, returns the source index.
    125     Index SourceIndexOfReshape(const Shape& output_shape,
    126                                const Shape& input_shape,
    127                                llvm::IRBuilder<>* builder) const;
    128 
    129     // Returns the index into the source operand from which a slice operation
    130     // selects a value to be placed into index "this". The slice is described
    131     // by starting indices `starts` and stride values `strides`.
    132     //
    133     // Precondition: "this" is an index into a slice whose operand shape is
    134     // `operand_shape`.
    135     Index SourceIndexOfSlice(const Shape& operand_shape,
    136                              absl::Span<const int64> starts,
    137                              absl::Span<const int64> strides,
    138                              llvm::IRBuilder<>* builder) const;
    139 
    140     // Given that "this" is the target index of a transpose from `operand_shape`
    141     // to `shape` with the given dimension mapping, returns the source index.
    142     Index SourceIndexOfTranspose(const Shape& shape, const Shape& operand_shape,
    143                                  absl::Span<const int64> dimension_mapping,
    144                                  llvm::IRBuilder<>* builder) const;
    145 
    146     // Given that "this" is the target index of a bitcast from `operand_shape`
    147     // to `shape`, returns the source index.
    148     Index SourceIndexOfBitcast(const Shape& shape, const Shape& operand_shape,
    149                                llvm::IRBuilder<>* builder) const;
    150 
    151     // Given that "this" is the target index of a broadcast from `operand_shape`
    152     // to `shape` with the given dimension mapping, returns the source index.
    153     Index SourceIndexOfBroadcast(const Shape& shape, const Shape& operand_shape,
    154                                  absl::Span<const int64> dimension_mapping,
    155                                  llvm::IRBuilder<>* builder) const;
    156 
    157     // Linearizes the index into the given shape, i.e. reshapes it to rank-1 and
    158     // returns the index into the sole dimension 0 of the new shape.
    159     llvm::Value* Linearize(absl::Span<const int64> dimensions,
    160                            llvm::IRBuilder<>* builder) const;
    161 
    162     llvm::Type* GetType() const { return index_type_; }
    163 
    164     llvm::Constant* GetConstantWithIndexType(int64 c) const {
    165       // The LLVM function makes sure that the value can be represented by the
    166       // specified type, see ConstantInt::ConstantInt(IntegerType *Ty, const
    167       // APInt &V).
    168       return llvm::ConstantInt::get(index_type_, c);
    169     }
    170 
    171    private:
    172     // Constructs an index from both a multi-dimensional index and a linear
    173     // index. 'shape' is the shape on which the index is used. 'index_type' is
    174     // the type of the index.
    175     //
    176     // Precondition: "shape" has a layout.
    177     Index(absl::Span<llvm::Value* const> multidim, llvm::Value* linear,
    178           const Shape& shape, llvm::Type* index_type);
    179 
    180     void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear,
    181                      const Shape& shape, llvm::IRBuilder<>* b) const;
    182 
    183     std::vector<llvm::Value*> multidim_;
    184 
    185     // These values are purely for efficiency; `multidim_` is enough to find the
    186     // element at a given `Index`, but if a loop is emitted with a linear index
    187     // space, that linear index can be saved in `linear_`, and the layout and
    188     // dimensions of the shape the loop was emitted for in `layout_` and
    189     // `dims_`, and if the `Index` is used in another array, and its layout and
    190     // dimensions match, the linear index can be used, sparing the cost of
    191     // computing `multidim_`, which LLVM DCE could potentially so delete.
    192     // Modifying `multidim_` after construction nullifies `linear_`, lest it
    193     // be used wrongly, as it would be valid no more.
    194     // If a loop is emitted with a multidimensional index space, `linear_` would
    195     // be null and `layout_` and `dims_` would be ignored.
    196     llvm::Value* linear_ = nullptr;
    197     Layout layout_;
    198     std::vector<int64> dims_;
    199 
    200     llvm::Type* index_type_;
    201   };
    202 
    203   // Default constructor. Constructs an IrArray in a null status.
    204   IrArray() : base_ptr_(nullptr) {}
    205 
    206   // Construct an IrArray with the given base pointer and shape. base_ptr is a
    207   // pointer type pointing to the first element(lowest address) of the array.
    208   IrArray(llvm::Value* base_ptr, Shape shape);
    209 
    210   // Default implementations of copying and moving.
    211   IrArray(IrArray&& other) = default;
    212   IrArray(const IrArray& other) = default;
    213   IrArray& operator=(IrArray&& other) = default;
    214   IrArray& operator=(const IrArray& other) = default;
    215 
    216   llvm::Value* GetBasePointer() const { return base_ptr_; }
    217   llvm::Type* GetElementLlvmType() const { return element_type_; }
    218 
    219   const Shape& GetShape() const { return shape_; }
    220 
    221   // Emit a sequence of instructions to compute the address of the element in
    222   // the given array at the given index. Returns the address of the element as
    223   // an LLVM Value.
    224   //
    225   // The optional name is useful for debugging when looking at
    226   // the emitted LLVM IR.
    227   llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b,
    228                                        absl::string_view name = "",
    229                                        bool use_linear_index = true) const;
    230 
    231   // Attach metadata this IrArray instance knows about to "instruction".
    232   void AnnotateLoadStoreInstructionWithMetadata(
    233       llvm::Instruction* instruction) const;
    234 
    235   // Emit IR to read an array element at the given index. Returns the read
    236   // result (effectively, a Value loaded from memory). This method seamlessly
    237   // handles scalar shapes by broadcasting their value to all indices (index is
    238   // ignored).
    239   //
    240   // The optional name is useful for debugging when looking at
    241   // the emitted LLVM IR.
    242   // 'use_linear_index' can be used to specify whether the linear index (if
    243   // available) or the multi-dimensional index should be used.
    244   llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b,
    245                                     absl::string_view name = "",
    246                                     bool use_linear_index = true) const;
    247 
    248   // Emit IR to write the given value to the array element at the given index.
    249   // 'use_linear_index' can be used to specify whether the linear index (if
    250   // available) or the multi-dimensional index should be used.
    251   void EmitWriteArrayElement(const Index& index, llvm::Value* value,
    252                              llvm::IRBuilder<>* b,
    253                              bool use_linear_index = true) const;
    254 
    255   // Returns a new IrArray whose shape is "new_shape" and base pointer is a
    256   // bitcast of the base pointer of "this" IrArray.
    257   // 'use_linear_index' can be used to specify whether the linear index (if
    258   // available) or the multi-dimensional index should be used.
    259   IrArray CastToShape(const Shape& new_shape, llvm::IRBuilder<>* b) const;
    260 
    261   void AddAliasScopeMetadata(llvm::MDNode* alias_scope) {
    262     CHECK_NE(alias_scope, nullptr);
    263     AddMetadata(llvm::LLVMContext::MD_alias_scope, alias_scope);
    264   }
    265 
    266   void AddNoaliasMetadata(llvm::MDNode* noalias) {
    267     CHECK_NE(noalias, nullptr);
    268     AddMetadata(llvm::LLVMContext::MD_noalias, noalias);
    269   }
    270 
    271   // Promises LLVM that the data pointed to by this IrArray never changes after
    272   // it's first loaded.
    273   //
    274   // The temporal scope of this promise is the "whole program" from LLVM's point
    275   // of view, but how this translates to HLOs differs between backends.
    276   //
    277   // In the single-threaded CPU backend, we emit one function that
    278   // runs all the HLOs in sequence, so the whole program is the whole HLO
    279   // module.
    280   //
    281   // In the GPU backend, we emit one GPU kernel per top-level HLO (i.e. per HLO
    282   // in the entry computation).  From LLVM's perspective, launching a new kernel
    283   // is like launching a new program, and so the whole program is one top-level
    284   // HLO.  Since the scope of the promise is smaller than in the CPU backend, we
    285   // can mark more things as invariant in the GPU backend.
    286   //
    287   // Marking loads as invariant is particularly helpful on GPUs because
    288   // invariant loads can be lowered to PTX ld.global.nc (equivalent to CUDA's
    289   // __ldg intrinsic).  These loads use a special cache, and can be
    290   // significantly faster than regular loads.
    291   void MarkInvariantOverWholeProgram(llvm::LLVMContext* context) {
    292     if (is_invariant_) {
    293       return;
    294     }
    295     is_invariant_ = true;
    296     AddMetadata(llvm::LLVMContext::MD_invariant_load,
    297                 llvm::MDNode::get(*context, {}));
    298   }
    299 
    300   const std::map<int, llvm::MDNode*>& metadata() const { return metadata_; }
    301 
    302  private:
    303   // Add the specified LLVM IR metadata to loads/stores associated with this
    304   // IrArray.
    305   void AddMetadata(int kind, llvm::MDNode* md) {
    306     InsertOrDie(&metadata_, kind, md);
    307   }
    308 
    309   // Address of the base of the array as an LLVM Value.
    310   llvm::Value* base_ptr_;
    311 
    312   // The LLVM type of the elements in the array.
    313   llvm::Type* element_type_;
    314 
    315   // Shape of the XLA array.
    316   Shape shape_;
    317 
    318   // The list of key/value pairs used when attaching metadata to emitted
    319   // loads/stores for this array.  They keys are the metadata kinds and the
    320   // values are the metadata nodes.
    321   std::map<int, llvm::MDNode*> metadata_;
    322 
    323   bool is_invariant_ = false;
    324 };
    325 
    326 }  // namespace llvm_ir
    327 }  // namespace xla
    328 
    329 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_
    330