Home | History | Annotate | Download | only in llvm_ir
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h"
     17 #include "tensorflow/compiler/xla/layout_util.h"
     18 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
     19 #include "tensorflow/compiler/xla/shape_util.h"
     20 #include "tensorflow/compiler/xla/statusor.h"
     21 #include "tensorflow/compiler/xla/util.h"
     22 #include "tensorflow/core/platform/logging.h"
     23 
     24 namespace xla {
     25 namespace llvm_ir {
     26 
     27 namespace {
     28 // Returns the indices of the first elements of all consecutive subarrays of the
     29 // given array. For example:
     30 // ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
     31 std::vector<size_t> ConsecutiveSegments(absl::Span<const int64> xs) {
     32   std::vector<size_t> is = {0};
     33   for (size_t i = 1; i < xs.size(); ++i) {
     34     if (1 != xs[i] - xs[i - 1]) {
     35       is.push_back(i);
     36     }
     37   }
     38   return is;
     39 }
     40 
     41 // Merges the sequences of dimensions of the given shape which start at the
     42 // given indices `segs`.
     43 Shape MergeDimensions(absl::Span<const size_t> segs, const Shape& shape) {
     44   std::vector<int64> dimensions;
     45   for (size_t i = 1; i <= segs.size(); ++i) {
     46     dimensions.push_back(std::accumulate(
     47         shape.dimensions().begin() + segs[i - 1],
     48         shape.dimensions().begin() +
     49             (segs.size() == i ? shape.dimensions().size() : segs[i]),
     50         1, std::multiplies<int64>()));
     51   }
     52   return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
     53                                                   dimensions);
     54 }
     55 
     56 // Given an index for a shape, return the equivalent new index if the shape is
     57 // reshaped to another shape.
     58 IrArray::Index GetReshapedIndex(const IrArray::Index& index, const Shape& shape,
     59                                 const Shape& reshaped_shape,
     60                                 llvm::IRBuilder<>* b) {
     61   auto bounds = shape.dimensions();
     62   auto minor_to_major = shape.layout().minor_to_major();
     63   llvm::Value* linear_index = index.GetConstantWithIndexType(0);
     64   int64 multiplier = 1;
     65   for (int i = 0; i < index.size(); ++i) {
     66     int64 dim = minor_to_major[i];
     67     llvm::Value* addend = b->CreateMul(
     68         index[dim], index.GetConstantWithIndexType(multiplier), "linearizing",
     69         /*HasNUW=*/true, /*HasNSW=*/true);
     70     linear_index = b->CreateAdd(linear_index, addend, "",
     71                                 /*HasNUW=*/true, /*HasNSW=*/true);
     72     multiplier *= bounds[dim];
     73   }
     74 
     75   return IrArray::Index(linear_index, reshaped_shape, b);
     76 }
     77 
     78 }  // namespace
     79 
     80 absl::optional<std::vector<int64> > FindTranspose021(const Shape& a,
     81                                                      const Shape& b) {
     82   if (!ShapeUtil::CompatibleIgnoringElementType(a, b)) {
     83     return absl::nullopt;
     84   }
     85 
     86   std::vector<int64> permutation(a.dimensions().size());
     87   absl::Span<const int64> minor_to_major_a = LayoutUtil::MinorToMajor(a);
     88   std::vector<int64> major_to_minor_a(minor_to_major_a.rbegin(),
     89                                       minor_to_major_a.rend());
     90   absl::Span<const int64> minor_to_major_b = LayoutUtil::MinorToMajor(b);
     91   std::vector<int64> major_to_minor_b(minor_to_major_b.rbegin(),
     92                                       minor_to_major_b.rend());
     93   for (size_t i = 0; i < permutation.size(); ++i) {
     94     permutation[i] = PositionInContainer(major_to_minor_b, major_to_minor_a[i]);
     95   }
     96 
     97   std::vector<size_t> segments = ConsecutiveSegments(permutation);
     98   if ((3 == segments.size() && 0 == permutation[0]) || 2 == segments.size()) {
     99     Shape descending_layout_shape =
    100         ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
    101     Shape normalized_shape = MergeDimensions(segments, descending_layout_shape);
    102     absl::Span<const int64> normalized_dims =
    103         AsInt64Slice(normalized_shape.dimensions());
    104     std::vector<int64> dims_021;
    105     if (2 == segments.size()) {
    106       // The logical component-0 is of size one.
    107       dims_021 = {1, normalized_dims[1], normalized_dims[0]};
    108     } else {
    109       dims_021 = {normalized_dims[0], normalized_dims[2], normalized_dims[1]};
    110     }
    111 
    112     return dims_021;
    113   }
    114 
    115   return absl::nullopt;
    116 }
    117 
    118 KernelMappingScheme::KernelMappingScheme(
    119     absl::Span<const int64> dims_in_elems, int64 tile_size_y, int64 tile_size_x,
    120     absl::Span<const int64> req_block_sizes, int64 num_threads_y,
    121     int64 num_threads_x, llvm::IRBuilder<>* b)
    122     : b_(b),
    123       dims_in_elems_(dims_in_elems.begin(), dims_in_elems.end()),
    124       tile_sizes_{1, tile_size_y, tile_size_x},
    125       num_threads_x_(num_threads_x),
    126       num_threads_y_(num_threads_y),
    127       dilated_x_(true) {
    128   DCHECK_EQ(dims_in_elems_.size(), 3);
    129   DCHECK_EQ(req_block_sizes.size(), 3);
    130 
    131   DCHECK_EQ(tile_size_y % num_threads_y_, 0);
    132   DCHECK_EQ(tile_size_x % num_threads_x_, 0);
    133 
    134   dims_in_tiles_ = ElementWiseCeilOfRatio<int64>(dims_in_elems_, tile_sizes_);
    135   block_sizes_.reserve(req_block_sizes.size());
    136   absl::c_transform(req_block_sizes, dims_in_tiles_,
    137                     std::back_inserter(block_sizes_),
    138                     [](const int64 requested_size, const int64 max_size) {
    139                       return std::min(requested_size, max_size);
    140                     });
    141   dims_in_blocks_ = ElementWiseCeilOfRatio<int64>(dims_in_tiles_, block_sizes_);
    142 
    143   VLOG(10) << "dims_in_elems_ = [" << absl::StrJoin(dims_in_elems_, ",") << "]";
    144   VLOG(10) << "dims_in_tiles_ = [" << absl::StrJoin(dims_in_tiles_, ",") << "]";
    145   VLOG(10) << "dims_in_blocks_ = [" << absl::StrJoin(dims_in_blocks_, ",")
    146            << "]";
    147 }
    148 
    149 IrArray::Index KernelMappingScheme::GetUnnormalizedIndex(
    150     const IrArray::Index& normalized_shape_index,
    151     const Shape& unnormalized_shape) {
    152   DCHECK_EQ(normalized_shape_index.size(), dims_in_elems_.size());
    153   Shape output_shape = ShapeUtil::MakeShapeWithDescendingLayout(
    154       unnormalized_shape.element_type(), GetDimensionsInElements());
    155   return GetReshapedIndex(normalized_shape_index, output_shape,
    156                           unnormalized_shape, b_);
    157 }
    158 
    159 IrArray::Index KernelMappingScheme::EmitBlockIndex(llvm::Type* index_ty) {
    160   llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic(
    161       llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_);
    162   llvm_ir::AddRangeMetadata(0, GetNumberOfBlocks(),
    163                             llvm::cast<llvm::Instruction>(block_id));
    164   llvm::Value* linear_block_id =
    165       b_->CreateIntCast(block_id, index_ty, /*isSigned=*/true, "block.id.x");
    166   return IrArray::Index(linear_block_id,
    167                         ShapeUtil::MakeShapeWithDescendingLayout(
    168                             PRED /*arbitrary*/, dims_in_blocks_),
    169                         b_);
    170 }
    171 
    172 IrArray::Index KernelMappingScheme::GetTileIndexForBlockOrigin(
    173     const IrArray::Index& block_index) {
    174   DCHECK_EQ(block_index.size(), block_sizes_.size());
    175   std::vector<llvm::Value*> multidim;
    176   multidim.reserve(block_sizes_.size());
    177   for (int i = 0; i < block_sizes_.size(); ++i) {
    178     multidim.push_back(b_->CreateMul(
    179         block_index[i],
    180         llvm::ConstantInt::get(block_index[i]->getType(), block_sizes_[i]),
    181         "block_origin." + std::to_string(i)));
    182   }
    183   return IrArray::Index(multidim, block_index[0]->getType());
    184 }
    185 
    186 IrArray::Index KernelMappingScheme::GetElementIndexForTileOrigin(
    187     const IrArray::Index& tile_index) {
    188   std::vector<llvm::Value*> elem_multi_index = tile_index.multidim();
    189   for (int i = DimY; i < DimTot; ++i) {
    190     elem_multi_index[i] =
    191         b_->CreateMul(tile_index[i],
    192                       llvm::ConstantInt::get(tile_index[i]->getType(),
    193                                              GetTileSizeForDimension(i)),
    194                       "tile_origin." + std::to_string(i));
    195   }
    196   return IrArray::Index(elem_multi_index, tile_index.GetType());
    197 }
    198 
    199 llvm::GlobalVariable* KernelMappingScheme::GetSharedMemoryBufferForElementType(
    200     llvm::Type* elem_ty, absl::string_view buffer_name) {
    201   // If shared memory tranpose is needed, we use square tiles.
    202   CHECK_EQ(GetTileSizeForDimensionX(), GetTileSizeForDimensionY());
    203 
    204   // For Nvidia GPUs, the warp size is 32 threads and the shared memory bank is
    205   // organized into 32-way. We usually use the warp size or a multiplier or a
    206   // the warp size as the size for tiling. This may cause all elements in the
    207   // same column of a tile use the same memory bank and therefore shared memory
    208   // bank conflicts. Adding 1 to the minor dimension of the shared memory buffer
    209   // can reduce such shared memory bank conflicts.
    210   llvm::Type* buffer_type = llvm::ArrayType::get(
    211       llvm::ArrayType::get(elem_ty, GetTileSizeForDimension(DimX) + 1),
    212       GetTileSizeForDimension(DimY));
    213   return llvm_ir::AllocateSharedMemoryTile(b_->GetInsertBlock()->getModule(),
    214                                            buffer_type, buffer_name);
    215 }
    216 
    217 std::tuple<llvm::Value*, llvm::Value*>
    218 KernelMappingScheme::EmitThreadYXCoordinate(llvm::Type* index_ty) {
    219   // Calculate (y, x) coordinate of the thread in the 2D view of thread block
    220   // defined by (num_thread_y, num_thread_x) from thread_id.
    221   llvm::CallInst* thread_id_raw = llvm_ir::EmitCallToIntrinsic(
    222       llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_);
    223   llvm_ir::AddRangeMetadata(0, GetThreadsPerBlock(), thread_id_raw);
    224   llvm::Value* thread_id_int =
    225       b_->CreateIntCast(thread_id_raw, index_ty,
    226                         /*isSigned=*/true, "thread.id.x");
    227   llvm::Value* num_thread_x =
    228       llvm::ConstantInt::get(index_ty, GetNumberOfThreadsForDimensionX());
    229   llvm::Value* x = b_->CreateURem(thread_id_int, num_thread_x, "thread.x");
    230   llvm::Value* y = b_->CreateUDiv(thread_id_int, num_thread_x, "thread.y");
    231   return std::make_tuple(y, x);
    232 }
    233 
    234 }  // namespace llvm_ir
    235 }  // namespace xla
    236