1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h" 17 18 #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" 19 #include "tensorflow/compiler/xla/service/hlo_module.h" 20 #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" 21 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" 22 23 namespace xla { 24 namespace cpu { 25 namespace { 26 27 using tensorflow::int64; 28 29 // Provides tiled access to an in-memory rank 2 array. 30 class MemoryTile { 31 public: 32 // Constructs a MemoryTile that can operate on tiles consisting of 33 // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at 34 // `major_dim_offset` in the major dimension. The tile size along the minor 35 // dimension is the vector size, and that is implicitly determined by `vsl`. 36 MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* b, 37 llvm::Value* matrix, int64 matrix_size_along_minor_dim, 38 llvm::Value* major_dim_offset, int64 tile_size_along_major_dim) 39 : vsl_(vsl), b_(b) { 40 pointers_.reserve(tile_size_along_major_dim); 41 for (int64 i = 0; i < tile_size_along_major_dim; i++) { 42 llvm::Value* total_offset = 43 b->CreateMul(b->getInt64(matrix_size_along_minor_dim), 44 b->CreateAdd(b->getInt64(i), major_dim_offset)); 45 pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset)); 46 } 47 } 48 49 // Load a tile consisting of `tile_size_along_major_dim` vectors from position 50 // {major: `major_dim_offset`, minor: `minor_dim_offset`}. 51 // 52 // Note: `major_dim_offset` is a parameter to the constructor. 53 std::vector<llvm::Value*> LoadTile(llvm::Value* minor_dim_offset) const { 54 std::vector<llvm::Value*> result; 55 result.reserve(pointers_.size()); 56 for (const auto& pointer : pointers_) { 57 result.push_back(vsl_->LoadVector(pointer, minor_dim_offset)); 58 } 59 return result; 60 } 61 62 // Stores `tile` to position {major: `major_dim_offset`, minor: 63 // `minor_dim_offset`}. 64 // 65 // Note: `major_dim_offset` is a parameter to the constructor. 66 void StoreTile(absl::Span<llvm::Value* const> tile, 67 llvm::Value* minor_dim_offset) const { 68 CHECK_EQ(tile.size(), pointers_.size()); 69 for (int64 i = 0; i < pointers_.size(); i++) { 70 vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset); 71 } 72 } 73 74 // Loads a tile of size [`tile_size_along_major_dim`, 75 // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`, 76 // minor: `minor_dim_offset`} and then broadcasts each element into a vector 77 // of size vsl_.vector_size(). The (i,j)'th element of the return value is 78 // the (i,j)'th element in the tile broadcasted into an LLVM vector. 79 // 80 // Note: `major_dim_offset` is a parameter to the constructor. 81 std::vector<std::vector<llvm::Value*>> LoadBroadcastTile( 82 llvm::Value* minor_dim_offset, int64 tile_size_along_middle_dim) const { 83 std::vector<std::vector<llvm::Value*>> result; 84 result.resize(pointers_.size()); 85 for (int64 i = 0; i < pointers_.size(); i++) { 86 for (int64 j = 0; j < tile_size_along_middle_dim; j++) { 87 result[i].push_back(vsl_->LoadBroadcast( 88 pointers_[i], b_->CreateAdd(minor_dim_offset, b_->getInt64(j)))); 89 } 90 } 91 return result; 92 } 93 94 private: 95 VectorSupportLibrary* vsl_; 96 llvm::IRBuilder<>* b_; 97 std::vector<llvm::Value*> pointers_; 98 }; 99 100 // The base class for the classes representing the GEMV emitter configurations. 101 // 102 // The IR emitted (modulo the LLVM values representing the input and output 103 // buffers) by the row major and column major GEMV emitters should be a function 104 // of their configuration. This is important because their configuration is 105 // used as a key to cache the generated IR. 106 class GemvConfig { 107 public: 108 // Mixin for convenience. 109 template <typename T> 110 struct User { 111 public: 112 PrimitiveType scalar_type() const { 113 return derived().config().scalar_type(); 114 } 115 int64 tile_rows() const { return derived().config().tile_rows(); } 116 int64 tile_cols() const { return derived().config().tile_cols(); } 117 int64 m() const { return derived().config().m(); } 118 int64 k() const { return derived().config().k(); } 119 int64 has_addend() const { return derived().config().has_addend(); } 120 121 private: 122 const T& derived() const { return *static_cast<const T*>(this); } 123 }; 124 125 PrimitiveType scalar_type() const { return scalar_type_; } 126 int64 tile_rows() const { return tile_rows_; } 127 int64 tile_cols() const { return tile_cols_; } 128 int64 m() const { return m_; } 129 int64 k() const { return k_; } 130 bool has_addend() const { return has_addend_; } 131 132 string GetCacheKey() const { 133 return absl::StrCat(name_, "_", PrimitiveType_Name(scalar_type()), "_", 134 tile_rows(), "_", tile_cols(), "_", m(), "_", k(), 135 has_addend() ? "_with_addend" : ""); 136 } 137 138 protected: 139 explicit GemvConfig(string name, PrimitiveType scalar_type, int64 tile_rows, 140 int64 tile_cols, int64 m, int64 k, bool has_addend) 141 : name_(std::move(name)), 142 scalar_type_(scalar_type), 143 tile_rows_(tile_rows), 144 tile_cols_(tile_cols), 145 m_(m), 146 k_(k), 147 has_addend_(has_addend) {} 148 149 private: 150 string name_; 151 PrimitiveType scalar_type_; 152 int64 tile_rows_; 153 int64 tile_cols_; 154 int64 m_; 155 int64 k_; 156 bool has_addend_; 157 }; 158 159 // Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the 160 // layout of the vector does not matter). This implementation uses a tiling 161 // scheme to improve performance. 162 // 163 // We logically separate the LHS matrix into four segments: 164 // 165 // +----------------------+---+ 166 // | | | 167 // | | | 168 // | A | B | 169 // | | | 170 // | | | 171 // | | | 172 // +----------------------+---+ 173 // | C | D | 174 // +----------------------+---+ 175 // 176 // where A is the largest submatrix of the LHS that can be evenly dividied into 177 // tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: 178 // 179 // +---+---+---+---+ +--+--+--+--+ 180 // |M00|M10|M20|M30| |V0|V1|V2|V3| 181 // +---+---+---+---+ +--+--+--+--+ 182 // |M01|M11|M21|M31| and |V0|V1|V2|V3| 183 // +---+---+---+---+ +--+--+--+--+ 184 // |M02|M12|M22|M32| |V0|V1|V2|V3| 185 // +---+---+---+---+ +--+--+--+--+ 186 // |M03|M13|M23|M33| |V0|V1|V2|V3| 187 // +---+---+---+---+ +--+--+--+--+ 188 // 189 // (Legend: rows are horizontal and columns are vertical; and each column is one 190 // llvm::Value of a vector type) 191 // 192 // where: 193 // 194 // a. The left tile is from the column major left matrix. 195 // b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3] 196 // vector loaded from the RHS vector. 197 // 198 // As we iterate through the column dimension, we compute the change to the 199 // result vector by an elementwise multiplication between the two tiles above 200 // followed by a reduction along the major dimension: 201 // 202 // +-----------------------------------+ 203 // | M00*V0 + M10*V1 + M20*V2 + M30*V3 | 204 // +-----------------------------------+ 205 // | M01*V0 + M11*V1 + M21*V2 + M31*V3 | 206 // Result[R:R+4] += +-----------------------------------+ 207 // | M02*V0 + M12*V1 + M22*V2 + M32*V3 | 208 // +-----------------------------------+ 209 // | M03*V0 + M13*V1 + M23*V2 + M33*V3 | 210 // +-----------------------------------+ 211 // 212 // Where R is the starting row for the tile. 213 // 214 // We have an inner epilogue loop to deal with the "C" submatrix and an outer 215 // epilogue loop to deal with the B,D submarix. 216 // 217 // TODO(sanjoy): We should investigate if using gather loads and scatter stores 218 // can be used here have the same inner loop for both column-major and row-major 219 // matrix-vector products. 220 class ColumnMajorMatrixVectorProductEmitter 221 : public GemvConfig::User<ColumnMajorMatrixVectorProductEmitter> { 222 public: 223 class Config : public GemvConfig { 224 public: 225 explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, 226 int64 m, int64 k, bool has_addend) 227 : GemvConfig(/*name=*/"col_major_gemv", scalar_type, 228 /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, 229 /*k=*/k, /*has_addend=*/has_addend) {} 230 }; 231 232 ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, 233 llvm::Value* rhs, llvm::Value* addend, 234 llvm::Value* result, 235 llvm::IRBuilder<>* b) 236 : config_(config), 237 lhs_(lhs), 238 rhs_(rhs), 239 addend_(addend), 240 result_(result), 241 b_(b), 242 ksl_(b_), 243 vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), b_, "") { 244 CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast<uint64>(tile_rows()))); 245 CHECK(!has_addend() || addend != nullptr); 246 } 247 248 void Emit(); 249 250 const Config& config() const { return config_; } 251 252 private: 253 void EmitOuterLoopBody(llvm::Value* column, int64 column_count, 254 bool is_first_column); 255 256 MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64 column_count) { 257 return MemoryTile(&vsl_, b_, /*matrix=*/lhs_, 258 /*matrix_size_along_minor_dim=*/m(), 259 /*major_dim_offset=*/column_start, 260 /*tile_size_along_major_dim=*/column_count); 261 } 262 263 // Load a tile of values from the RHS. For the RHS a "tile" is a contiguous 264 // sequence of `count` values, each one broadcasted to the vector width. 265 std::vector<llvm::Value*> LoadRhsTile(llvm::Value* offset, int64 count) { 266 llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset); 267 std::vector<llvm::Value*> result; 268 result.reserve(count); 269 for (int64 i = 0; i < count; i++) { 270 result.push_back(vsl_.LoadBroadcast(base_pointer, i)); 271 } 272 return result; 273 } 274 275 void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, 276 const std::vector<llvm::Value*>& rhs_tile, 277 int64 columns, bool is_first_column); 278 279 void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns, 280 bool is_first_tiled_column); 281 282 Config config_; 283 llvm::Value* lhs_; 284 llvm::Value* rhs_; 285 llvm::Value* addend_; 286 llvm::Value* result_; 287 llvm::IRBuilder<>* b_; 288 KernelSupportLibrary ksl_; 289 VectorSupportLibrary vsl_; 290 }; 291 292 void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody( 293 llvm::Value* column, int64 column_count, bool is_first_column) { 294 MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*column_start=*/column, 295 /*column_count=*/column_count); 296 297 std::vector<llvm::Value*> rhs_tile = 298 LoadRhsTile(column, /*count=*/column_count); 299 EmitInnerLoopTiled(&lhs_memory_tile, rhs_tile, 300 /*columns=*/column_count, is_first_column); 301 EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column); 302 } 303 304 void ColumnMajorMatrixVectorProductEmitter::Emit() { 305 // See the comment on the class declaration for the algorithm used here. 306 int64 column_remainder = k() % tile_cols(); 307 int64 column_limit = k() - column_remainder; 308 309 ksl_.For("dot.outer.tiled", 310 /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), 311 [&](llvm::Value* column, bool is_first_column) { 312 EmitOuterLoopBody(column, tile_cols(), is_first_column); 313 }); 314 315 if (column_remainder != 0) { 316 EmitOuterLoopBody(b_->getInt64(column_limit), column_remainder, 317 column_limit == 0); 318 } 319 } 320 321 void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( 322 MemoryTile* lhs_memory_tile, const std::vector<llvm::Value*>& rhs_tile, 323 int64 columns, bool is_first_column) { 324 int64 row_limit = m() - (m() % tile_rows()); 325 326 ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit, 327 /*step=*/tile_rows(), [&](llvm::Value* row) { 328 std::vector<llvm::Value*> lhs_tile = 329 lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row); 330 llvm::Value* accumulator = 331 is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row) 332 : vsl_.GetZeroVector()) 333 : vsl_.LoadVector(result_, row); 334 for (int i = 0; i < columns; i++) { 335 accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); 336 } 337 vsl_.StoreVector(accumulator, result_, row); 338 }); 339 } 340 341 void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( 342 llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) { 343 int64 row_start = m() - (m() % tile_rows()); 344 if (row_start == m()) { 345 return; 346 } 347 348 llvm::Value* columns_llvm = b_->getInt64(columns); 349 350 // for (col = current_tile_col; col < (columns + current_tile_col); col++) 351 // for (row = row_start, row < m_; row++) { 352 // result[row] += lhs[row, col] * rhs[col] 353 // // Also take into account that if col is 0 then result[row] is not 354 // // initialized. 355 // } 356 357 ksl_.For( 358 "dot.inner.epilg.outer", /*start=*/current_tile_col, 359 /*end=*/b_->CreateAdd(columns_llvm, current_tile_col), 360 /*step=*/1, /*peel_first_iteration=*/false, 361 [&](llvm::Value* col, llvm::Value* is_first_scalar_col) { 362 llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col); 363 llvm::Value* total_offset = b_->CreateMul(col, b_->getInt64(m())); 364 llvm::Value* lhs_base_pointer = 365 vsl_.ComputeOffsetPointer(lhs_, total_offset); 366 ksl_.For( 367 "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(), 368 /*step=*/1, [&](llvm::Value* scalar_row) { 369 llvm::Value* product = vsl_.Mul( 370 vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element); 371 llvm::Value* setting_result_first_time = b_->CreateAnd( 372 is_first_scalar_col, b_->getInt1(is_first_tiled_column)); 373 ksl_.If( 374 setting_result_first_time, 375 /*true_block_generator=*/ 376 [&]() { 377 if (addend_) { 378 vsl_.StoreScalar( 379 vsl_.Add(vsl_.LoadScalar(addend_, scalar_row), 380 product), 381 result_, scalar_row); 382 } else { 383 vsl_.StoreScalar(product, result_, scalar_row); 384 } 385 }, 386 /*false_block_generator=*/ 387 [&]() { 388 vsl_.StoreScalar( 389 vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product), 390 result_, scalar_row); 391 }); 392 }); 393 }); 394 } 395 396 // Computes a dot product between "[M,K]{1,0} lhs" with a [K,1] vector (the 397 // layout of the vector does not matter). This implementation uses a tiling 398 // scheme to improve performance. 399 // 400 // We logically separate the LHS matrix into four segments: 401 // 402 // +----------------------+---+ 403 // | | | 404 // | | | 405 // | A | B | 406 // | | | 407 // | | | 408 // | | | 409 // +----------------------+---+ 410 // | C | D | 411 // +----------------------+---+ 412 // 413 // where A is the largest submatrix of the LHS that can be evenly dividied into 414 // tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: 415 // 416 // +---+---+---+---+ 417 // |M00|M10|M20|M30| 418 // +---+---+---+---+ +--+--+--+--+ 419 // |M01|M11|M21|M31| and |V0|V1|V2|V3| 420 // +---+---+---+---+ +--+--+--+--+ 421 // |M02|M12|M22|M32| 422 // +---+---+---+---+ 423 // |M03|M13|M23|M33| 424 // +---+---+---+---+ 425 // 426 // (Legend: rows are horizontal and columns are vertical; and each row is one 427 // llvm::Value of a vector type) 428 // 429 // where: 430 // 431 // a. The left tile is loaded from the row major left matrix. 432 // b. The right vector is loaded from the RHS vector. 433 // 434 // We keep 4 vector accumulators accumulating the following four vector 435 // expressions as we iterate over the row dimension: 436 // 437 // +------+------+------+------+ 438 // |M0I*V0|M1I*V1|M2I*V2|M3I*V3| for I in [0,4) 439 // +------+------+------+------+ 440 // 441 // In the end we do a horizontal reduction over these 4 vector accumulators to 442 // get 4 values in the result vector. 443 // 444 // We have an inner epilogue loop to deal with the "B" sub-matrix and an outer 445 // epilogue loop to deal with the C,D submatrix. 446 class RowMajorMatrixVectorProductEmitter 447 : public GemvConfig::User<RowMajorMatrixVectorProductEmitter> { 448 public: 449 class Config : public GemvConfig { 450 public: 451 explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, 452 int64 m, int64 k, bool has_addend) 453 : GemvConfig(/*name=*/"row_major_gemv", scalar_type, 454 /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, 455 /*k=*/k, /*has_addend=*/has_addend) {} 456 }; 457 458 RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, 459 llvm::Value* rhs, llvm::Value* addend, 460 llvm::Value* result, llvm::IRBuilder<>* b) 461 : config_(config), 462 lhs_(lhs), 463 rhs_(rhs), 464 addend_(addend), 465 result_(result), 466 b_(b), 467 ksl_(b_), 468 vsl_(scalar_type(), /*vector_size=*/tile_cols(), b_, "") { 469 CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast<uint64>(tile_cols()))); 470 CHECK(!has_addend() || addend != nullptr); 471 } 472 473 void Emit(); 474 475 const Config& config() const { return config_; } 476 477 private: 478 MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64 row_count) { 479 return MemoryTile(&vsl_, b_, /*matrix=*/lhs_, 480 /*matrix_size_along_minor_dim=*/k(), 481 /*major_dim_offset=*/row_start, 482 /*tile_size_along_major_dim=*/row_count); 483 } 484 485 void EmitOuterLoopBody(llvm::Value* row, int64 row_count); 486 487 void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, int64 rows, 488 std::vector<VectorVariable>* vector_accumulators); 489 490 void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows, 491 std::vector<ScalarVariable>* scalar_accumulators); 492 493 Config config_; 494 llvm::Value* lhs_; 495 llvm::Value* rhs_; 496 llvm::Value* addend_; 497 llvm::Value* result_; 498 llvm::IRBuilder<>* b_; 499 KernelSupportLibrary ksl_; 500 VectorSupportLibrary vsl_; 501 }; 502 503 void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, 504 int64 row_count) { 505 MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*row_start=*/row, 506 /*row_count=*/row_count); 507 std::vector<VectorVariable> vector_accumulators; 508 std::vector<ScalarVariable> scalar_accumulators; 509 for (int i = 0; i < row_count; i++) { 510 vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector()); 511 scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar()); 512 } 513 EmitInnerLoopTiled(&lhs_memory_tile, /*rows=*/row_count, 514 &vector_accumulators); 515 EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count, 516 &scalar_accumulators); 517 518 std::vector<llvm::Value*> accumulator_values; 519 std::transform( 520 vector_accumulators.begin(), vector_accumulators.end(), 521 std::back_inserter(accumulator_values), 522 [](const VectorVariable& vector_var) { return vector_var.Get(); }); 523 524 std::vector<llvm::Value*> horizontal_sums; 525 if (row_count == vsl_.vector_size()) { 526 if (addend_) { 527 horizontal_sums = vsl_.ComputeHorizontalSums( 528 std::move(accumulator_values), vsl_.LoadVector(addend_, row)); 529 } else { 530 horizontal_sums = 531 vsl_.ComputeHorizontalSums(std::move(accumulator_values)); 532 } 533 } else { 534 horizontal_sums = vsl_.ComputeHorizontalSums(std::move(accumulator_values)); 535 } 536 537 for (int i = 0; i < row_count; i++) { 538 llvm::Value* result_value = 539 vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get()); 540 llvm::Value* offset = b_->CreateAdd(b_->getInt64(i), row); 541 if (addend_ && row_count != vsl_.vector_size()) { 542 result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value); 543 } 544 vsl_.StoreScalar(result_value, result_, offset); 545 } 546 } 547 548 void RowMajorMatrixVectorProductEmitter::Emit() { 549 // See the comment on the class declaration for the algorithm used here. 550 int64 row_remainder = m() % tile_rows(); 551 int64 row_limit = m() - row_remainder; 552 553 ksl_.For("dot.outer.tiled", 554 /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), 555 [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); 556 557 if (row_remainder != 0) { 558 EmitOuterLoopBody(b_->getInt64(row_limit), row_remainder); 559 } 560 } 561 562 void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( 563 MemoryTile* lhs_memory_tile, int64 rows, 564 std::vector<VectorVariable>* vector_accumulators) { 565 int64 column_limit = k() - (k() % tile_cols()); 566 567 ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, 568 /*step=*/tile_cols(), [&](llvm::Value* col) { 569 std::vector<llvm::Value*> lhs_tile = 570 lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); 571 llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); 572 for (int i = 0; i < rows; i++) { 573 llvm::Value* old_sum = (*vector_accumulators)[i].Get(); 574 (*vector_accumulators)[i].Set( 575 vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); 576 } 577 }); 578 } 579 580 void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( 581 llvm::Value* current_tile_row, int64 rows, 582 std::vector<ScalarVariable>* scalar_accumulators) { 583 int64 column_start = k() - (k() % tile_cols()); 584 if (column_start == k()) { 585 return; 586 } 587 588 for (int r = 0; r < rows; r++) { 589 llvm::Value* total_offset = b_->CreateMul( 590 b_->CreateAdd(b_->getInt64(r), current_tile_row), b_->getInt64(k())); 591 llvm::Value* lhs_base_pointer = 592 vsl_.ComputeOffsetPointer(lhs_, total_offset); 593 ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(), 594 /*step=*/1, [&](llvm::Value* scalar_col) { 595 llvm::Value* product = 596 vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), 597 vsl_.LoadScalar(rhs_, scalar_col)); 598 llvm::Value* old_value = (*scalar_accumulators)[r].Get(); 599 (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); 600 }); 601 } 602 } 603 604 // This class implements a tiled matrix multiplication algorithm, intended for 605 // multiplying small matrices that don't need cache tiling. 606 // 607 // In the future this can be used as the innermost GEBP loop in a GEMM kernel as 608 // described in "Goto, Kazushige, and Robert A. Geijn. "Anatomy of 609 // high-performance matrix multiplication." ACM Transactions on Mathematical 610 // Software (TOMS) 34.3 (2008): 12.". 611 // 612 // This only supports canonical dot operations (i.e. where the lhs contraction 613 // dimension is 1 and the rhs contraction dimension is 0) over row major 614 // matrices. 615 class TiledSmallGemmEmitter { 616 public: 617 // Describe the dimensions of the kernel. 618 class Dimensions { 619 public: 620 explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {} 621 622 int64 m() const { return m_; } 623 int64 k() const { return k_; } 624 int64 n() const { return n_; } 625 626 string ToString() const { return absl::StrCat(m(), "x", k(), "x", n()); } 627 628 private: 629 const int64 m_; 630 const int64 k_; 631 const int64 n_; 632 }; 633 634 // Represents the configuration of the emitter. The LLVM IR emitted by the 635 // emitter, modulo the LLVM values holding the input and output buffers, must 636 // be a function of the instance of `Config` passed to it. 637 // 638 // `dims` holds the matrix multiplication dimensions. 639 // 640 // `max_vectorization_width` is the maximum vector width (i.e. the width of 641 // the largest vector register we will use). This can be larger than the 642 // largest vector register supported by the machine -- LLVM will legalize 643 // these large vector widths into legally sized vectors. 644 // 645 // `max_vector_count` is the maximum number of vectors of size 646 // `max_vectorization_width` that we will attempt to process at once. 647 // 648 // `min_vectorization_width` is the smallest vector width the emitter will use 649 // -- below that it will devolve to using a scalar loop. 650 // 651 // The innermost reduction loop executes the matrix multiply in tiles of size 652 // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`, 653 // <vectorization width>] in the RHS. 654 class Config { 655 public: 656 explicit Config(PrimitiveType scalar_type, Dimensions dims, 657 int64 max_vectorization_width, int64 max_vector_count, 658 int64 min_vectorization_width, int64 tile_size_m, 659 int64 tile_size_k) 660 : scalar_type_(scalar_type), 661 dims_(dims), 662 max_vectorization_width_(max_vectorization_width), 663 max_vector_count_(max_vector_count), 664 min_vectorization_width_(min_vectorization_width), 665 tile_size_m_(tile_size_m), 666 tile_size_k_(tile_size_k) {} 667 668 string GetCacheKey() const { 669 return absl::StrCat("gemm_", PrimitiveType_Name(scalar_type()), "_", 670 dims().ToString(), "_", max_vectorization_width(), 671 "_", min_vectorization_width(), "_", tile_size_m(), 672 "_", tile_size_k()); 673 } 674 675 PrimitiveType scalar_type() const { return scalar_type_; } 676 Dimensions dims() const { return dims_; } 677 int64 max_vectorization_width() const { return max_vectorization_width_; } 678 int64 max_vector_count() const { return max_vector_count_; } 679 int64 min_vectorization_width() const { return min_vectorization_width_; } 680 681 int64 tile_size_m() const { return tile_size_m_; } 682 int64 tile_size_k() const { return tile_size_k_; } 683 684 private: 685 PrimitiveType scalar_type_; 686 Dimensions dims_; 687 int64 max_vectorization_width_; 688 int64 max_vector_count_; 689 int64 min_vectorization_width_; 690 int64 tile_size_m_; 691 int64 tile_size_k_; 692 }; 693 694 // Creates an instance of TiledSmallGemmEmitter that matrix-multiplies 695 // `lhs` with `rhs` and stores the result in `result`. 696 explicit TiledSmallGemmEmitter(Config config, llvm::Value* lhs, 697 llvm::Value* rhs, llvm::Value* result, 698 llvm::IRBuilder<>* b) 699 : lhs_(lhs), 700 rhs_(rhs), 701 result_(result), 702 config_(config), 703 b_(b), 704 ksl_(b_) { 705 CHECK(max_vectorization_width() > 0 && 706 IsPowerOfTwo(static_cast<uint64>(max_vectorization_width()))); 707 CHECK_GT(max_vector_count(), 0); 708 CHECK(min_vectorization_width() > 0 && 709 IsPowerOfTwo(static_cast<uint64>(min_vectorization_width()))); 710 CHECK_GE(max_vectorization_width(), min_vectorization_width()); 711 CHECK_GT(tile_size_k(), 0); 712 } 713 714 void Emit(); 715 716 private: 717 // The HandleResiduesOnX helpers split the iteration space for dimension X 718 // into a multiple of the tile size on dimension X and an epilogue. These 719 // helpers ultimately call into `EmitTiledGemm` for emitting the 720 // tiled GEMM kernel. 721 722 void HandleResiduesOnN(); 723 void HandleResiduesOnK(VectorSupportLibrary* vsl, llvm::Value* n_start, 724 llvm::Value* n_end); 725 void HandleResiduesOnM(VectorSupportLibrary* vsl, int64 tile_size_k, 726 llvm::Value* k_start, llvm::Value* k_end, 727 llvm::Value* n_start, llvm::Value* n_end); 728 729 // This emits a tiled GEMM kernel. For a detailed description see the comment 730 // on the implementation. 731 void EmitTiledGemm(VectorSupportLibrary* vsl, int64 tile_size_k, 732 llvm::Value* k_start, llvm::Value* k_end, 733 llvm::Value* n_start, llvm::Value* n_end, 734 int64 tile_size_m, llvm::Value* m_start, 735 llvm::Value* m_end); 736 737 llvm::Value* GetInt64(int64 value) { return b_->getInt64(value); } 738 739 Config config() const { return config_; } 740 Dimensions dims() const { return config().dims(); } 741 742 int64 max_vectorization_width() const { 743 return config().max_vectorization_width(); 744 } 745 int64 max_vector_count() const { return config().max_vector_count(); } 746 int64 min_vectorization_width() const { 747 return config().min_vectorization_width(); 748 } 749 int64 tile_size_m() const { return config().tile_size_m(); } 750 int64 tile_size_k() const { return config().tile_size_k(); } 751 PrimitiveType scalar_type() const { return config().scalar_type(); } 752 753 llvm::Value* lhs_; 754 llvm::Value* rhs_; 755 llvm::Value* result_; 756 Config config_; 757 758 llvm::IRBuilder<>* b_; 759 KernelSupportLibrary ksl_; 760 }; 761 762 void TiledSmallGemmEmitter::Emit() { HandleResiduesOnN(); } 763 764 void TiledSmallGemmEmitter::HandleResiduesOnN() { 765 // We can only iterate the `n` dimension for an extent that is divisible by 766 // the vectorization width. So we emit an outer loop that first processes the 767 // largest extent in `n` that is divisible by max_vectorization_width, then 768 // the largest remaining extent that is divisible by max_vectorization_width / 769 // 2 etc. 770 771 int64 current_vectorization_width = 772 max_vector_count() * max_vectorization_width(); 773 int64 current_vector_count = max_vector_count(); 774 775 int64 n_start = 0; 776 while (n_start != dims().n() && 777 current_vectorization_width >= min_vectorization_width()) { 778 int64 n_end = dims().n() - (dims().n() % current_vectorization_width); 779 if (n_start != n_end) { 780 VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_, 781 "gemm"); 782 HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end)); 783 n_start = n_end; 784 } 785 if (current_vector_count == 1) { 786 current_vectorization_width /= 2; 787 } else { 788 current_vector_count--; 789 current_vectorization_width = 790 current_vector_count * max_vectorization_width(); 791 } 792 } 793 794 if (n_start != dims().n()) { 795 VectorSupportLibrary vsl(scalar_type(), 1, b_, "gemm"); 796 ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { 797 llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1)); 798 HandleResiduesOnK(&vsl, n_i, n_i_next); 799 }); 800 } 801 } 802 803 void TiledSmallGemmEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, 804 llvm::Value* n_start, 805 llvm::Value* n_end) { 806 int64 k_start = 0; 807 int64 k_end = dims().k() - (dims().k() % tile_size_k()); 808 if (k_end != k_start) { 809 HandleResiduesOnM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end), 810 n_start, n_end); 811 k_start = k_end; 812 } 813 814 if (k_start != dims().k()) { 815 HandleResiduesOnM(vsl, dims().k() - k_start, GetInt64(k_start), 816 GetInt64(dims().k()), n_start, n_end); 817 } 818 } 819 820 void TiledSmallGemmEmitter::HandleResiduesOnM( 821 VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, 822 llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) { 823 const int64 m_end = dims().m() - dims().m() % tile_size_m(); 824 EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, tile_size_m(), 825 GetInt64(0), GetInt64(m_end)); 826 827 if (m_end != dims().m()) { 828 EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, 829 dims().m() - m_end, GetInt64(m_end), GetInt64(dims().m())); 830 } 831 } 832 833 // The loop structure is: 834 // 835 // Iterate over dimension M as m: 836 // Iterate over dimension N as n: 837 // Iterate over dimension K as k: 838 // OutputTile[m,n] += Dot(LhsTile[m,k], RhsTile[k,n]) 839 // 840 // I.e. a just a tiled version of a "naive" GEMM. 841 // 842 // The tiling scheme is as follows: 843 // 844 // Let the LHS be: 845 // 846 // +----+----+----+ 847 // | a0 | b0 | c0 | . 848 // +----+----+----+ . 849 // | a1 | b1 | c1 | . 850 // +----+----+----+ 851 // .. .. 852 // 853 // and the RHS be: 854 // 855 // +----+----+----+----+ 856 // | p0 | p1 | p2 | p3 | . 857 // +----+----+----+----+ . 858 // | q0 | q1 | q2 | q3 | . 859 // +----+----+----+----+ 860 // | r0 | r1 | r2 | r3 | . 861 // +----+----+----+----+ . 862 // ...... ...... 863 // 864 // and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly denoted 865 // by `vsl`) be 4. Then we want to matrix multiply this tile to get a [2,4] 866 // matrix that we can increment the result matrix by. 867 // 868 // First broadcast the rows row in LHS to 3 vectors of width 4, giving us a rank 869 // 3 array, L, of dimension [2,3,4]: 870 // 871 // L[0,_,_] * L[1,_,_] 872 // * 873 // +----+----+----+----+ * +----+----+----+----+ 874 // | a0 | a0 | a0 | a0 | * | a1 | a1 | a1 | a1 | 875 // +----+----+----+----+ * +----+----+----+----+ 876 // | b0 | b0 | b0 | b0 | * | b1 | b1 | b1 | b1 | 877 // +----+----+----+----+ * +----+----+----+----+ 878 // | c0 | c0 | c0 | c0 | * | c1 | c1 | c1 | c1 | 879 // +----+----+----+----+ * +----+----+----+----+ 880 // 881 // 882 // Then we FMA L[0,_,_] with the RHS to get the first row of the result and 883 // L[1,_,_] with the RHS to get the second row of the result. For example, 884 // L[0,_,_] is computed as: 885 // 886 // +----+----+----+----+ +----+----+----+----+ 887 // | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 | + 888 // +----+----+----+----+ +----+----+----+----+ 889 // 890 // +----+----+----+----+ +----+----+----+----+ 891 // | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 | + 892 // +----+----+----+----+ +----+----+----+----+ 893 // 894 // +----+----+----+----+ +----+----+----+----+ 895 // | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 | 896 // +----+----+----+----+ +----+----+----+----+ 897 // 898 // to get: 899 // 900 // +-------------------+-------------------+-------------------+--------- 901 // | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ... 902 // +-------------------+-------------------+-------------------+--------- 903 void TiledSmallGemmEmitter::EmitTiledGemm( 904 VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, 905 llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, 906 int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) { 907 ksl_.For("dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { 908 MemoryTile result_memory_tile(vsl, b_, /*matrix=*/result_, 909 /*matrix_size_along_minor_dim=*/dims().n(), 910 /*major_dim_offset=*/m_i, 911 /*tile_size_along_major_dim=*/tile_size_m); 912 MemoryTile lhs_memory_tile(vsl, b_, /*matrix=*/lhs_, 913 /*matrix_size_along_minor_dim=*/dims().k(), 914 /*major_dim_offset=*/m_i, 915 /*tile_size_along_major_dim=*/tile_size_m); 916 ksl_.For( 917 "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { 918 TileVariable result_tile_var(vsl, result_memory_tile.LoadTile(n_i)); 919 ksl_.For("dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { 920 MemoryTile rhs_memory_tile(vsl, b_, rhs_, dims().n(), k_i, 921 tile_size_k); 922 std::vector<std::vector<llvm::Value*>> lhs_tile = 923 lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k); 924 std::vector<llvm::Value*> rhs_tile = rhs_memory_tile.LoadTile(n_i); 925 std::vector<llvm::Value*> result_tile = result_tile_var.Get(); 926 for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) { 927 for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) { 928 result_tile[r_m_i] = 929 vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i], 930 result_tile[r_m_i]); 931 } 932 } 933 result_tile_var.Set(result_tile); 934 }); 935 936 result_memory_tile.StoreTile(result_tile_var.Get(), n_i); 937 }); 938 }); 939 } 940 941 llvm::Type* GetPointerToElementType(llvm::Type* pointer_type) { 942 llvm::Type* type = 943 llvm::cast<llvm::PointerType>(pointer_type)->getElementType(); 944 while (auto* array_type = llvm::dyn_cast<llvm::ArrayType>(type)) { 945 type = array_type->getElementType(); 946 } 947 948 return type->getPointerTo(); 949 } 950 951 struct GemvBuffersWithCanonicalType { 952 llvm::Value* lhs_canonicalized; 953 llvm::Value* rhs_canonicalized; 954 llvm::Value* addend_canonicalized; 955 llvm::Value* result_canonicalized; 956 }; 957 958 GemvBuffersWithCanonicalType GetGemvBuffersWithCanonicalType( 959 llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, 960 llvm::Value* result, llvm::IRBuilder<>* b) { 961 // We characterize a GEMV operation via M and K, since N is implicitly 1. 962 // This means the GEMV that multiplies (say) [5,6] with [6,1] is implemented 963 // by the same GEMV that multiplies [5,6] with [1,6]. However, the 964 // `llvm::Types` for the inputs to the two GEMVs don't match (in a trivial 965 // sense -- the in memory representations are the same) since they're computed 966 // from the `xla::Shape`s. Since we want to be able to call the same 967 // `llvm::Function` for the two GEMVs we canonicalize the types of the GEMV 968 // inputs here into the same type. 969 GemvBuffersWithCanonicalType buffers_with_canonical_type; 970 llvm::Type* lhs_type = lhs->getType(); 971 llvm::Type* rhs_type = rhs->getType(); 972 llvm::Type* addend_type = addend ? addend->getType() : nullptr; 973 llvm::Type* result_type = result->getType(); 974 975 buffers_with_canonical_type.lhs_canonicalized = 976 b->CreateBitCast(lhs, GetPointerToElementType(lhs_type)); 977 buffers_with_canonical_type.rhs_canonicalized = 978 b->CreateBitCast(rhs, GetPointerToElementType(rhs_type)); 979 buffers_with_canonical_type.addend_canonicalized = 980 addend ? b->CreateBitCast(addend, GetPointerToElementType(addend_type)) 981 : nullptr; 982 buffers_with_canonical_type.result_canonicalized = 983 b->CreateBitCast(result, GetPointerToElementType(result_type)); 984 985 return buffers_with_canonical_type; 986 } 987 988 } // namespace 989 990 void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows, 991 int64 tile_cols, int64 m, int64 k, llvm::Value* lhs, 992 llvm::Value* rhs, llvm::Value* addend, 993 llvm::Value* result, llvm::IRBuilder<>* b, 994 const HloModuleConfig& module_config) { 995 RowMajorMatrixVectorProductEmitter::Config config( 996 /*scalar_type=*/scalar_type, 997 /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, 998 /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr); 999 1000 GemvBuffersWithCanonicalType canonical_inputs = 1001 GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b); 1002 1003 KernelSupportLibrary::EmitAndCallOutlinedKernel( 1004 module_config, b, config.GetCacheKey(), 1005 canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized, 1006 canonical_inputs.addend_canonicalized, 1007 canonical_inputs.result_canonicalized, 1008 [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs, 1009 llvm::Value* addend, 1010 llvm::Value* result) { 1011 RowMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend, 1012 result, b); 1013 emitter.Emit(); 1014 }); 1015 } 1016 1017 void EmitColumnMajorGemv(PrimitiveType scalar_type, int64 tile_rows, 1018 int64 tile_cols, int64 m, int64 k, llvm::Value* lhs, 1019 llvm::Value* rhs, llvm::Value* addend, 1020 llvm::Value* result, llvm::IRBuilder<>* b, 1021 const HloModuleConfig& module_config) { 1022 ColumnMajorMatrixVectorProductEmitter::Config config( 1023 /*scalar_type=*/scalar_type, 1024 /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, 1025 /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr); 1026 1027 GemvBuffersWithCanonicalType canonical_inputs = 1028 GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b); 1029 1030 KernelSupportLibrary::EmitAndCallOutlinedKernel( 1031 module_config, b, config.GetCacheKey(), 1032 canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized, 1033 canonical_inputs.addend_canonicalized, 1034 canonical_inputs.result_canonicalized, 1035 [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs, 1036 llvm::Value* addend, 1037 llvm::Value* result) { 1038 ColumnMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend, 1039 result, b); 1040 emitter.Emit(); 1041 }); 1042 } 1043 1044 void EmitSmallGemm(PrimitiveType scalar_type, int64 m, int64 k, int64 n, 1045 int64 max_vectorization_width, int64 max_vector_count, 1046 int64 min_vectorization_width, int64 tile_size_m, 1047 int64 tile_size_k, llvm::Value* lhs, llvm::Value* rhs, 1048 llvm::Value* result, llvm::IRBuilder<>* b, 1049 const HloModuleConfig& module_config) { 1050 TiledSmallGemmEmitter::Config config( 1051 /*scalar_type=*/scalar_type, 1052 TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, 1053 /*max_vectorization_width=*/max_vectorization_width, 1054 /*max_vector_count=*/max_vector_count, 1055 /*min_vectorization_width=*/min_vectorization_width, 1056 /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k); 1057 1058 KernelSupportLibrary::EmitAndCallOutlinedKernel( 1059 module_config, b, config.GetCacheKey(), lhs, rhs, result, 1060 [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result) { 1061 TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs, 1062 /*rhs=*/rhs, 1063 /*result=*/result, b); 1064 small_gemm_emitter.Emit(); 1065 }); 1066 } 1067 1068 } // namespace cpu 1069 } // namespace xla 1070