/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace llvm_ir {

namespace {
// Returns the indices of the first elements of all consecutive subarrays of
// the given array. For example:
// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
std::vector<size_t> ConsecutiveSegments(absl::Span<const int64> xs) {
  std::vector<size_t> is = {0};
  for (size_t i = 1; i < xs.size(); ++i) {
    if (1 != xs[i] - xs[i - 1]) {
      is.push_back(i);
    }
  }
  return is;
}

// Merges the sequences of dimensions of the given shape which start at the
// given indices `segs`.
Shape MergeDimensions(absl::Span<const size_t> segs, const Shape& shape) {
  std::vector<int64> dimensions;
  for (size_t i = 1; i <= segs.size(); ++i) {
    dimensions.push_back(std::accumulate(
        shape.dimensions().begin() + segs[i - 1],
        shape.dimensions().begin() +
            (segs.size() == i ? shape.dimensions().size() : segs[i]),
        1, std::multiplies<int64>()));
  }
  return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
                                                  dimensions);
}

// Given an index for a shape, return the equivalent new index if the shape is
// reshaped to another shape.
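// For example, with a row-major shape [A, B], the index (i, j) linearizes to
// i * B + j, and that linear value is then delinearized against
// `reshaped_shape` to form the new multidimensional index.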
IrArray::Index GetReshapedIndex(const IrArray::Index& index, const Shape& shape,
                                const Shape& reshaped_shape,
                                llvm::IRBuilder<>* b) {
  auto bounds = shape.dimensions();
  auto minor_to_major = shape.layout().minor_to_major();
  llvm::Value* linear_index = index.GetConstantWithIndexType(0);
  int64 multiplier = 1;
  for (int i = 0; i < index.size(); ++i) {
    int64 dim = minor_to_major[i];
    llvm::Value* addend = b->CreateMul(
        index[dim], index.GetConstantWithIndexType(multiplier), "linearizing",
        /*HasNUW=*/true, /*HasNSW=*/true);
    linear_index = b->CreateAdd(linear_index, addend, "",
                                /*HasNUW=*/true, /*HasNSW=*/true);
    multiplier *= bounds[dim];
  }

  return IrArray::Index(linear_index, reshaped_shape, b);
}

}  // namespace

absl::optional<std::vector<int64> > FindTranspose021(const Shape& a,
                                                     const Shape& b) {
  if (!ShapeUtil::CompatibleIgnoringElementType(a, b)) {
    return absl::nullopt;
  }

  std::vector<int64> permutation(a.dimensions().size());
  absl::Span<const int64> minor_to_major_a = LayoutUtil::MinorToMajor(a);
  std::vector<int64> major_to_minor_a(minor_to_major_a.rbegin(),
                                      minor_to_major_a.rend());
  absl::Span<const int64> minor_to_major_b = LayoutUtil::MinorToMajor(b);
  std::vector<int64> major_to_minor_b(minor_to_major_b.rbegin(),
                                      minor_to_major_b.rend());
  for (size_t i = 0; i < permutation.size(); ++i) {
    permutation[i] = PositionInContainer(major_to_minor_b, major_to_minor_a[i]);
  }

  std::vector<size_t> segments = ConsecutiveSegments(permutation);
  if ((3 == segments.size() && 0 == permutation[0]) || 2 == segments.size()) {
    Shape descending_layout_shape =
        ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
    Shape normalized_shape = MergeDimensions(segments, descending_layout_shape);
    absl::Span<const int64> normalized_dims =
        AsInt64Slice(normalized_shape.dimensions());
    std::vector<int64> dims_021;
    if (2 == segments.size()) {
      // The logical component-0 is of size one.
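      // For example, a row-major to column-major copy of a 2D shape [m, n]
      // yields the permutation {1, 0} and is reported as the 0-2-1 transpose
      // shape {1, n, m}.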
      dims_021 = {1, normalized_dims[1], normalized_dims[0]};
    } else {
      dims_021 = {normalized_dims[0], normalized_dims[2], normalized_dims[1]};
    }

    return dims_021;
  }

  return absl::nullopt;
}

KernelMappingScheme::KernelMappingScheme(
    absl::Span<const int64> dims_in_elems, int64 tile_size_y, int64 tile_size_x,
    absl::Span<const int64> req_block_sizes, int64 num_threads_y,
    int64 num_threads_x, llvm::IRBuilder<>* b)
    : b_(b),
      dims_in_elems_(dims_in_elems.begin(), dims_in_elems.end()),
      tile_sizes_{1, tile_size_y, tile_size_x},
      num_threads_x_(num_threads_x),
      num_threads_y_(num_threads_y),
      dilated_x_(true) {
  DCHECK_EQ(dims_in_elems_.size(), 3);
  DCHECK_EQ(req_block_sizes.size(), 3);

  DCHECK_EQ(tile_size_y % num_threads_y_, 0);
  DCHECK_EQ(tile_size_x % num_threads_x_, 0);

  dims_in_tiles_ = ElementWiseCeilOfRatio<int64>(dims_in_elems_, tile_sizes_);
  block_sizes_.reserve(req_block_sizes.size());
  absl::c_transform(req_block_sizes, dims_in_tiles_,
                    std::back_inserter(block_sizes_),
                    [](const int64 requested_size, const int64 max_size) {
                      return std::min(requested_size, max_size);
                    });
  dims_in_blocks_ = ElementWiseCeilOfRatio<int64>(dims_in_tiles_, block_sizes_);

  VLOG(10) << "dims_in_elems_ = [" << absl::StrJoin(dims_in_elems_, ",") << "]";
  VLOG(10) << "dims_in_tiles_ = [" << absl::StrJoin(dims_in_tiles_, ",") << "]";
  VLOG(10) << "dims_in_blocks_ = [" << absl::StrJoin(dims_in_blocks_, ",")
           << "]";
}

IrArray::Index KernelMappingScheme::GetUnnormalizedIndex(
    const IrArray::Index& normalized_shape_index,
    const Shape& unnormalized_shape) {
  DCHECK_EQ(normalized_shape_index.size(), dims_in_elems_.size());
  Shape output_shape = ShapeUtil::MakeShapeWithDescendingLayout(
      unnormalized_shape.element_type(), GetDimensionsInElements());
  return GetReshapedIndex(normalized_shape_index, output_shape,
                          unnormalized_shape, b_);
}

IrArray::Index KernelMappingScheme::EmitBlockIndex(llvm::Type* index_ty) {
  llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_);
  llvm_ir::AddRangeMetadata(0, GetNumberOfBlocks(),
                            llvm::cast<llvm::Instruction>(block_id));
  llvm::Value* linear_block_id =
      b_->CreateIntCast(block_id, index_ty, /*isSigned=*/true, "block.id.x");
  return IrArray::Index(linear_block_id,
                        ShapeUtil::MakeShapeWithDescendingLayout(
                            PRED /*arbitrary*/, dims_in_blocks_),
                        b_);
}

IrArray::Index KernelMappingScheme::GetTileIndexForBlockOrigin(
    const IrArray::Index& block_index) {
  DCHECK_EQ(block_index.size(), block_sizes_.size());
  std::vector<llvm::Value*> multidim;
  multidim.reserve(block_sizes_.size());
  for (int i = 0; i < block_sizes_.size(); ++i) {
    multidim.push_back(b_->CreateMul(
        block_index[i],
        llvm::ConstantInt::get(block_index[i]->getType(), block_sizes_[i]),
        "block_origin." + std::to_string(i)));
  }
  return IrArray::Index(multidim, block_index[0]->getType());
}

IrArray::Index KernelMappingScheme::GetElementIndexForTileOrigin(
    const IrArray::Index& tile_index) {
  std::vector<llvm::Value*> elem_multi_index = tile_index.multidim();
  for (int i = DimY; i < DimTot; ++i) {
    elem_multi_index[i] =
        b_->CreateMul(tile_index[i],
                      llvm::ConstantInt::get(tile_index[i]->getType(),
                                             GetTileSizeForDimension(i)),
                      "tile_origin." + std::to_string(i));
  }
  return IrArray::Index(elem_multi_index, tile_index.GetType());
}

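// Allocates a shared memory tile buffer of the given element type with shape
// [tile_size_y][tile_size_x + 1]; the extra element in the minor dimension
// avoids shared memory bank conflicts (see the comment in the body).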
llvm::GlobalVariable* KernelMappingScheme::GetSharedMemoryBufferForElementType(
    llvm::Type* elem_ty, absl::string_view buffer_name) {
  // If a shared memory transpose is needed, we use square tiles.
  CHECK_EQ(GetTileSizeForDimensionX(), GetTileSizeForDimensionY());

  // For Nvidia GPUs, the warp size is 32 threads and shared memory is
  // organized into 32 banks. We usually use the warp size, or a multiple of
  // the warp size, as the tile size. This can cause all elements in the same
  // column of a tile to fall into the same memory bank and therefore lead to
  // shared memory bank conflicts. Adding 1 to the minor dimension of the
  // shared memory buffer reduces such bank conflicts.
  llvm::Type* buffer_type = llvm::ArrayType::get(
      llvm::ArrayType::get(elem_ty, GetTileSizeForDimension(DimX) + 1),
      GetTileSizeForDimension(DimY));
  return llvm_ir::AllocateSharedMemoryTile(b_->GetInsertBlock()->getModule(),
                                           buffer_type, buffer_name);
}

std::tuple<llvm::Value*, llvm::Value*>
KernelMappingScheme::EmitThreadYXCoordinate(llvm::Type* index_ty) {
  // Calculate the (y, x) coordinate of the thread within the 2D view of the
  // thread block defined by (num_thread_y, num_thread_x), starting from the
  // linear thread id.
  llvm::CallInst* thread_id_raw = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_);
  llvm_ir::AddRangeMetadata(0, GetThreadsPerBlock(), thread_id_raw);
  llvm::Value* thread_id_int =
      b_->CreateIntCast(thread_id_raw, index_ty,
                        /*isSigned=*/true, "thread.id.x");
  llvm::Value* num_thread_x =
      llvm::ConstantInt::get(index_ty, GetNumberOfThreadsForDimensionX());
  llvm::Value* x = b_->CreateURem(thread_id_int, num_thread_x, "thread.x");
  llvm::Value* y = b_->CreateUDiv(thread_id_int, num_thread_x, "thread.y");
  return std::make_tuple(y, x);
}

}  // namespace llvm_ir
}  // namespace xla