/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_TILING_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_TILING_H_

#include "absl/algorithm/container.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"

namespace xla {
namespace llvm_ir {

// About 0-2-1 transpose:
//
// If a shape can be viewed as three logical components 0-1-2 in the order of
// major to minor, a 0-2-1 transpose reorders those logical components to
// 0-2-1. We call the shape being transposed the input shape and the
// transposed shape the output shape. The logical views of the input/output
// shapes for the transpose are called the 0-1-2/0-2-1 shapes or the
// normalized shapes. The original input/output shapes are called the
// unnormalized shapes.
//
// If `b` is a 0-2-1 transpose of `a`, returns the dimensions of the
// normalized (0-2-1) shape of `b`; otherwise returns absl::nullopt.
absl::optional<std::vector<int64> > FindTranspose021(const Shape& a,
                                                     const Shape& b);
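//
// Example usage (a minimal sketch; `input_shape` and `output_shape` are
// hypothetical shapes supplied by the caller):
//
//   absl::optional<std::vector<int64>> reduced_dims =
//       FindTranspose021(input_shape, output_shape);
//   if (reduced_dims.has_value()) {
//     // The transpose can be emitted as a tiled 0-2-1 transpose over a
//     // normalized tensor of (*reduced_dims)[0] x (*reduced_dims)[1] x
//     // (*reduced_dims)[2] elements.
//   }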

// A tile is a spatial subdivision of a tensor. We group tensor elements into
// tiles so that we can launch kernels to process the tensor elements in
// blocks of tiles.
//
// A kernel mapping scheme describes a method to partition the tensors
// accessed by an unnested HLO instruction into tiles and blocks of tiles, and
// the associated information needed to use hardware threads to process the
// tensor elements in blocks of tiles.
//
// Currently, there are two main use cases for a tiling scheme. First, we
// implement kernels with 0-2-1 memory transpose using shared memory to
// improve the memory access pattern. Second, we implement reduction to
// contiguous dimensions in layout, with or without memory transpose, to
// achieve a better memory access pattern as well as to reduce the number of
// executed expensive instructions, such as thread-synchronization-related
// instructions and atomic operations. For both use cases, we can apply a
// normalization to the original tensors, collapsing contiguous dimensions
// for the same purpose, to produce normalized three-dimensional tensors. For
// this reason, the tiling scheme class only needs to handle normalized
// three-dimensional tensors and two-dimensional tiles.
//
// The current implementation of the class is somewhat NVIDIA-GPU oriented.
// This situation can be improved when there is a need though. The idea of
// 0-2-1 transpose using shared memory can be found in the following CUDA
// algorithm in TensorFlow: https://goo.gl/MStRV6.
//
// We use a thread block to process a tile because we want to use the HW
// thread block synchronization primitives to synchronize the processing of
// all the elements in the same tile. A thread block can be viewed as a
// two-dimensional array of threads, described by the number of threads for
// the Y and X dimensions. A thread block (num_threads_y, num_threads_x)
// processes a tile of (tile_size_y, tile_size_x) as follows: each thread in
// the thread block processes one element in the tile so that all the threads
// in the thread block together process a subdivision of the tile that has
// the same dimensions as the thread block array. Then the thread block moves
// on to process the next subdivision of the tile until the whole tile is
// processed. Therefore, each thread in the thread block processes
// tile_size_x/num_threads_x * tile_size_y/num_threads_y elements in a tile.
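//
// For example (illustrative numbers, not defaults from this file): a
// (num_threads_y, num_threads_x) = (8, 32) thread block processing a
// (tile_size_y, tile_size_x) = (32, 32) tile covers an 8x32 subdivision at
// each step and advances four times in the Y direction, so each thread
// processes 32/32 * 32/8 = 4 elements of the tile.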
//
// There are situations where we want a thread block to process multiple
// tiles. We can't group those tiles into a bigger tile because we limit a
// tile to a two-dimensional spatial subdivision of a tensor. For example,
// when we use tiling to implement reduction with transpose, we want the
// partial sum produced by each thread to accumulate values for more elements
// before using shfl_down and atomic_add instructions for further reduction,
// to amortize the cost of such expensive instructions. The concept of a tile
// block is introduced for this purpose. A tile block is a three-dimensional
// array of tiles, of which some dimensions may be degenerate, i.e. contain
// only one tile.
class KernelMappingScheme {
 public:
  enum { DimZ = 0, DimY, DimX, DimTot };

  KernelMappingScheme() {}
  // dims_in_elems: the normalized tensor dimensions.
  // req_block_sizes: the requested block size in number of tiles for each
  //   dimension. The actual block size in each dimension is the minimum of
  //   the requested size and the number of tiles in that dimension.
  KernelMappingScheme(absl::Span<const int64> dims_in_elems, int64 tile_size_y,
                      int64 tile_size_x,
                      absl::Span<const int64> req_block_sizes,
                      int64 num_threads_y, int64 num_threads_x,
                      llvm::IRBuilder<>* b);
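
  // Example construction (a minimal sketch with illustrative values;
  // `builder` is an existing llvm::IRBuilder<>* in scope):
  //
  //   KernelMappingScheme mapping_scheme(
  //       /*dims_in_elems=*/{8, 1024, 1024},
  //       /*tile_size_y=*/32, /*tile_size_x=*/32,
  //       /*req_block_sizes=*/{1, 1, 1},
  //       /*num_threads_y=*/8, /*num_threads_x=*/32, builder);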

  absl::Span<const int64> GetDimensionsInElements() const {
    return dims_in_elems_;
  }
  absl::Span<const int64> GetDimensionsInTiles() const {
    return dims_in_tiles_;
  }
  absl::Span<const int64> GetDimensionsInBlocks() const {
    return dims_in_blocks_;
  }

  int64 GetNumberOfTilesInTotal() const {
    return absl::c_accumulate(dims_in_tiles_, 1LL, std::multiplies<int64>());
  }
  int64 GetNumberOfTilesInOneBlock() const {
    return absl::c_accumulate(block_sizes_, 1LL, std::multiplies<int64>());
  }
  int64 GetNumberOfTilesInOneBlockForDimension(int d) const {
    DCHECK(d >= DimZ && d <= DimX);
    return block_sizes_[d];
  }
  int64 GetNumberOfBlocks() const {
    return absl::c_accumulate(dims_in_blocks_, 1LL, std::multiplies<int64>());
  }
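
  // For example (illustrative numbers, continuing the sketch above,
  // and assuming a tile size of one in the Z dimension): with
  // dims_in_elems = {8, 1024, 1024} and tile sizes {1, 32, 32}, the tensor
  // is covered by dims_in_tiles = {8, 32, 32} tiles, so
  // GetNumberOfTilesInTotal() returns 8 * 32 * 32 = 8192.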

  int64 GetTileSizeForDimension(int d) const {
    DCHECK(d >= DimZ && d <= DimX);
    return tile_sizes_[d];
  }
  int64 GetTileSizeForDimensionX() const {
    return GetTileSizeForDimension(DimX);
  }
  int64 GetTileSizeForDimensionY() const {
    return GetTileSizeForDimension(DimY);
  }

  absl::Span<const int64> GetBlockSizes() const { return block_sizes_; }
  int64 GetTileBlockSizeForDimension(int d) const {
    DCHECK(d >= DimZ && d <= DimX);
    return dims_in_blocks_[d];
  }

  int64 GetNumberOfThreadsForDimensionX() const { return num_threads_x_; }
  int64 GetNumberOfThreadsForDimensionY() const { return num_threads_y_; }

  int64 GetThreadsPerBlock() const {
    return GetNumberOfThreadsForDimensionX() *
           GetNumberOfThreadsForDimensionY();
  }

  bool DilatedX() const { return dilated_x_; }
  void SetDilatedX(bool v) {
    dilated_x_ = v;
    if (!dilated_x_) {
      // dilated_x_=false is for the purpose of vectorization, which requires
      // GetTileSizeForDimension(DimX) to be a multiple of num_threads_x_.
      CHECK_EQ(GetTileSizeForDimension(DimX) % num_threads_x_, 0);
    }
  }

  IrArray::Index EmitBlockIndex(llvm::Type* index_ty);
  // Returns the index for the first tile in the block with the given block
  // index.
  IrArray::Index GetTileIndexForBlockOrigin(const IrArray::Index& block_index);
  // Returns the index for the first element in the tile with the given tile
  // index.
  IrArray::Index GetElementIndexForTileOrigin(const IrArray::Index& tile_index);
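
  // Typical emission sequence (an illustrative sketch; `index_ty` is a
  // hypothetical integer type chosen by the caller):
  //
  //   IrArray::Index block_index = mapping_scheme.EmitBlockIndex(index_ty);
  //   IrArray::Index tile_index =
  //       mapping_scheme.GetTileIndexForBlockOrigin(block_index);
  //   IrArray::Index element_index =
  //       mapping_scheme.GetElementIndexForTileOrigin(tile_index);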

  std::tuple<llvm::Value*, llvm::Value*> EmitThreadYXCoordinate(
      llvm::Type* index_ty);

  IrArray::Index GetUnnormalizedIndex(
      const IrArray::Index& normalized_shape_index,
      const Shape& unnormalized_shape);

  llvm::GlobalVariable* GetSharedMemoryBufferForElementType(
      llvm::Type* elem_ty, absl::string_view buffer_name);
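
  // Example (an illustrative sketch; `context` is a hypothetical
  // llvm::LLVMContext in scope):
  //
  //   llvm::GlobalVariable* tile_buffer =
  //       mapping_scheme.GetSharedMemoryBufferForElementType(
  //           llvm::Type::getFloatTy(context), "tile_buffer");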

 private:
  llvm::IRBuilder<>* b_;
  // The number of elements in each dimension.
  std::vector<int64> dims_in_elems_;

  // The number of elements for each dimension of a tile.
  std::vector<int64> tile_sizes_;
  // The number of tiles in each dimension. It is computed from dims_in_elems_
  // and tile_sizes_.
  std::vector<int64> dims_in_tiles_;

  // The number of tiles for each dimension of a tile block.
  std::vector<int64> block_sizes_;
  // The number of tile blocks in each dimension. It is computed from
  // dims_in_tiles_ and block_sizes_.
  std::vector<int64> dims_in_blocks_;

  // Number of threads used to process elements in the X direction of a tile.
  int64 num_threads_x_;
  // Number of threads used to process elements in the Y direction of a tile.
  int64 num_threads_y_;

  // When num_threads_x threads process a total of tile_size_x elements in
  // the X dimension of a tile, each thread processes
  // n = tile_size_x/num_threads_x elements. When dilated_x=false, the n
  // elements processed by a thread are contiguous. On the other hand, when
  // dilated_x=true, the n elements are dilated by a factor of num_threads_x.
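  //
  // For example (illustrative numbers): with tile_size_x = 64 and
  // num_threads_x = 32, each thread processes n = 2 elements: thread t
  // processes elements x = 2*t and x = 2*t + 1 when dilated_x_ is false, and
  // elements x = t and x = t + 32 when dilated_x_ is true.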
  bool dilated_x_;
};

// A class to represent information for tiled parameters to support IR
// emission for 0-2-1 transpose.
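//
// Example usage (a minimal sketch; `tile_buffers`, `y_in_tile`, and
// `x_in_tile` are hypothetical llvm::Value*s produced by the tiling emitter):
//
//   TiledParameterInfo tiled_param_info(tile_buffers, y_in_tile, x_in_tile);
//   llvm::Value* buffer = tiled_param_info.GetBufferForParameter(0);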
class TiledParameterInfo {
 public:
  TiledParameterInfo(absl::Span<llvm::Value* const> param_buffers,
                     llvm::Value* y, llvm::Value* x)
      : param_buffers_(param_buffers), y_(y), x_(x) {}

  llvm::Value* x() const { return x_; }
  llvm::Value* y() const { return y_; }

  void set_x(llvm::Value* x) { x_ = x; }
  void set_y(llvm::Value* y) { y_ = y; }

  llvm::Value* GetBufferForParameter(int64 index) const {
    return param_buffers_[index];
  }

 private:
  // param_buffers_[i] stores the tile buffer for the ith parameter, or
  // nullptr if the parameter is not tiled.
  absl::Span<llvm::Value* const> param_buffers_;
  // The y coordinate within a tile.
  llvm::Value* y_;
  // The x coordinate within a tile.
  llvm::Value* x_;
};

}  // namespace llvm_ir
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_TILING_H_