Home | History | Annotate | Download | only in cpu
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"
     17 
     18 #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
     19 #include "tensorflow/compiler/xla/service/hlo_module.h"
     20 #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
     21 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
     22 
     23 namespace xla {
     24 namespace cpu {
     25 namespace {
     26 
     27 using tensorflow::int64;
     28 
     29 // Provides tiled access to an in-memory rank 2 array.
     30 class MemoryTile {
     31  public:
     32   // Constructs a MemoryTile that can operate on tiles consisting of
     33   // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at
     34   // `major_dim_offset` in the major dimension.  The tile size along the minor
     35   // dimension is the vector size, and that is implicitly determined by `vsl`.
     36   MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* b,
     37              llvm::Value* matrix, int64 matrix_size_along_minor_dim,
     38              llvm::Value* major_dim_offset, int64 tile_size_along_major_dim)
     39       : vsl_(vsl), b_(b) {
     40     pointers_.reserve(tile_size_along_major_dim);
     41     for (int64 i = 0; i < tile_size_along_major_dim; i++) {
     42       llvm::Value* total_offset =
     43           b->CreateMul(b->getInt64(matrix_size_along_minor_dim),
     44                        b->CreateAdd(b->getInt64(i), major_dim_offset));
     45       pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset));
     46     }
     47   }
     48 
     49   // Load a tile consisting of `tile_size_along_major_dim` vectors from position
     50   // {major: `major_dim_offset`, minor: `minor_dim_offset`}.
     51   //
     52   // Note: `major_dim_offset` is a parameter to the constructor.
     53   std::vector<llvm::Value*> LoadTile(llvm::Value* minor_dim_offset) const {
     54     std::vector<llvm::Value*> result;
     55     result.reserve(pointers_.size());
     56     for (const auto& pointer : pointers_) {
     57       result.push_back(vsl_->LoadVector(pointer, minor_dim_offset));
     58     }
     59     return result;
     60   }
     61 
     62   // Stores `tile` to position {major: `major_dim_offset`, minor:
     63   // `minor_dim_offset`}.
     64   //
     65   // Note: `major_dim_offset` is a parameter to the constructor.
     66   void StoreTile(absl::Span<llvm::Value* const> tile,
     67                  llvm::Value* minor_dim_offset) const {
     68     CHECK_EQ(tile.size(), pointers_.size());
     69     for (int64 i = 0; i < pointers_.size(); i++) {
     70       vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset);
     71     }
     72   }
     73 
     74   // Loads a tile of size [`tile_size_along_major_dim`,
     75   // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`,
     76   // minor: `minor_dim_offset`} and then broadcasts each element into a vector
     77   // of size vsl_.vector_size().  The (i,j)'th element of the return value is
     78   // the (i,j)'th element in the tile broadcasted into an LLVM vector.
     79   //
     80   // Note: `major_dim_offset` is a parameter to the constructor.
     81   std::vector<std::vector<llvm::Value*>> LoadBroadcastTile(
     82       llvm::Value* minor_dim_offset, int64 tile_size_along_middle_dim) const {
     83     std::vector<std::vector<llvm::Value*>> result;
     84     result.resize(pointers_.size());
     85     for (int64 i = 0; i < pointers_.size(); i++) {
     86       for (int64 j = 0; j < tile_size_along_middle_dim; j++) {
     87         result[i].push_back(vsl_->LoadBroadcast(
     88             pointers_[i], b_->CreateAdd(minor_dim_offset, b_->getInt64(j))));
     89       }
     90     }
     91     return result;
     92   }
     93 
     94  private:
     95   VectorSupportLibrary* vsl_;
     96   llvm::IRBuilder<>* b_;
     97   std::vector<llvm::Value*> pointers_;
     98 };
     99 
    100 // The base class for the classes representing the GEMV emitter configurations.
    101 //
    102 // The IR emitted (modulo the LLVM values representing the input and output
    103 // buffers) by the row major and column major GEMV emitters should be a function
    104 // of their configuration.  This is important because their configuration is
    105 // used as a key to cache the generated IR.
    106 class GemvConfig {
    107  public:
    108   // Mixin for convenience.
    109   template <typename T>
    110   struct User {
    111    public:
    112     PrimitiveType scalar_type() const {
    113       return derived().config().scalar_type();
    114     }
    115     int64 tile_rows() const { return derived().config().tile_rows(); }
    116     int64 tile_cols() const { return derived().config().tile_cols(); }
    117     int64 m() const { return derived().config().m(); }
    118     int64 k() const { return derived().config().k(); }
    119     int64 has_addend() const { return derived().config().has_addend(); }
    120 
    121    private:
    122     const T& derived() const { return *static_cast<const T*>(this); }
    123   };
    124 
    125   PrimitiveType scalar_type() const { return scalar_type_; }
    126   int64 tile_rows() const { return tile_rows_; }
    127   int64 tile_cols() const { return tile_cols_; }
    128   int64 m() const { return m_; }
    129   int64 k() const { return k_; }
    130   bool has_addend() const { return has_addend_; }
    131 
    132   string GetCacheKey() const {
    133     return absl::StrCat(name_, "_", PrimitiveType_Name(scalar_type()), "_",
    134                         tile_rows(), "_", tile_cols(), "_", m(), "_", k(),
    135                         has_addend() ? "_with_addend" : "");
    136   }
    137 
    138  protected:
    139   explicit GemvConfig(string name, PrimitiveType scalar_type, int64 tile_rows,
    140                       int64 tile_cols, int64 m, int64 k, bool has_addend)
    141       : name_(std::move(name)),
    142         scalar_type_(scalar_type),
    143         tile_rows_(tile_rows),
    144         tile_cols_(tile_cols),
    145         m_(m),
    146         k_(k),
    147         has_addend_(has_addend) {}
    148 
    149  private:
    150   string name_;
    151   PrimitiveType scalar_type_;
    152   int64 tile_rows_;
    153   int64 tile_cols_;
    154   int64 m_;
    155   int64 k_;
    156   bool has_addend_;
    157 };
    158 
    159 // Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the
    160 // layout of the vector does not matter).  This implementation uses a tiling
    161 // scheme to improve performance.
    162 //
    163 // We logically separate the LHS matrix into four segments:
    164 //
    165 //   +----------------------+---+
    166 //   |                      |   |
    167 //   |                      |   |
    168 //   |         A            | B |
    169 //   |                      |   |
    170 //   |                      |   |
    171 //   |                      |   |
    172 //   +----------------------+---+
    173 //   |         C            | D |
    174 //   +----------------------+---+
    175 //
    176 // where A is the largest submatrix of the LHS that can be evenly divided into
    177 // tiles.  For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
    178 //
    179 //   +---+---+---+---+       +--+--+--+--+
    180 //   |M00|M10|M20|M30|       |V0|V1|V2|V3|
    181 //   +---+---+---+---+       +--+--+--+--+
    182 //   |M01|M11|M21|M31| and   |V0|V1|V2|V3|
    183 //   +---+---+---+---+       +--+--+--+--+
    184 //   |M02|M12|M22|M32|       |V0|V1|V2|V3|
    185 //   +---+---+---+---+       +--+--+--+--+
    186 //   |M03|M13|M23|M33|       |V0|V1|V2|V3|
    187 //   +---+---+---+---+       +--+--+--+--+
    188 //
    189 // (Legend: rows are horizontal and columns are vertical; and each column is one
    190 // llvm::Value of a vector type)
    191 //
    192 // where:
    193 //
    194 //   a. The left tile is from the column major left matrix.
    195 //   b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3]
    196 //      vector loaded from the RHS vector.
    197 //
    198 // As we iterate through the column dimension, we compute the change to the
    199 // result vector by an elementwise multiplication between the two tiles above
    200 // followed by a reduction along the major dimension:
    201 //
    202 //                     +-----------------------------------+
    203 //                     | M00*V0 + M10*V1 + M20*V2 + M30*V3 |
    204 //                     +-----------------------------------+
    205 //                     | M01*V0 + M11*V1 + M21*V2 + M31*V3 |
    206 // Result[R:R+4] +=    +-----------------------------------+
    207 //                     | M02*V0 + M12*V1 + M22*V2 + M32*V3 |
    208 //                     +-----------------------------------+
    209 //                     | M03*V0 + M13*V1 + M23*V2 + M33*V3 |
    210 //                     +-----------------------------------+
    211 //
    212 // Where R is the starting row for the tile.
    213 //
    214 // We have an inner epilogue loop to deal with the "C" submatrix and an outer
    215 // epilogue loop to deal with the B,D submatrix.
    216 //
    217 // TODO(sanjoy): We should investigate whether gather loads and scatter stores
    218 // can be used here to have the same inner loop for both column-major and row-major
    219 // matrix-vector products.
    220 class ColumnMajorMatrixVectorProductEmitter
    221     : public GemvConfig::User<ColumnMajorMatrixVectorProductEmitter> {
    222  public:
    223   class Config : public GemvConfig {
    224    public:
    225     explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols,
    226                     int64 m, int64 k, bool has_addend)
    227         : GemvConfig(/*name=*/"col_major_gemv", scalar_type,
    228                      /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m,
    229                      /*k=*/k, /*has_addend=*/has_addend) {}
    230   };
    231 
    232   ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs,
    233                                         llvm::Value* rhs, llvm::Value* addend,
    234                                         llvm::Value* result,
    235                                         llvm::IRBuilder<>* b)
    236       : config_(config),
    237         lhs_(lhs),
    238         rhs_(rhs),
    239         addend_(addend),
    240         result_(result),
    241         b_(b),
    242         ksl_(b_),
    243         vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), b_, "") {
    244     CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast<uint64>(tile_rows())));
    245     CHECK(!has_addend() || addend != nullptr);
    246   }
    247 
    248   void Emit();
    249 
    250   const Config& config() const { return config_; }
    251 
    252  private:
    253   void EmitOuterLoopBody(llvm::Value* column, int64 column_count,
    254                          bool is_first_column);
    255 
    256   MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64 column_count) {
    257     return MemoryTile(&vsl_, b_, /*matrix=*/lhs_,
    258                       /*matrix_size_along_minor_dim=*/m(),
    259                       /*major_dim_offset=*/column_start,
    260                       /*tile_size_along_major_dim=*/column_count);
    261   }
    262 
    263   // Load a tile of values from the RHS.  For the RHS a "tile" is a contiguous
    264   // sequence of `count` values, each one broadcasted to the vector width.
    265   std::vector<llvm::Value*> LoadRhsTile(llvm::Value* offset, int64 count) {
    266     llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset);
    267     std::vector<llvm::Value*> result;
    268     result.reserve(count);
    269     for (int64 i = 0; i < count; i++) {
    270       result.push_back(vsl_.LoadBroadcast(base_pointer, i));
    271     }
    272     return result;
    273   }
    274 
    275   void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile,
    276                           const std::vector<llvm::Value*>& rhs_tile,
    277                           int64 columns, bool is_first_column);
    278 
    279   void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns,
    280                              bool is_first_tiled_column);
    281 
    282   Config config_;
    283   llvm::Value* lhs_;
    284   llvm::Value* rhs_;
    285   llvm::Value* addend_;
    286   llvm::Value* result_;
    287   llvm::IRBuilder<>* b_;
    288   KernelSupportLibrary ksl_;
    289   VectorSupportLibrary vsl_;
    290 };
    291 
    292 void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody(
    293     llvm::Value* column, int64 column_count, bool is_first_column) {
    294   MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*column_start=*/column,
    295                                                 /*column_count=*/column_count);
    296 
    297   std::vector<llvm::Value*> rhs_tile =
    298       LoadRhsTile(column, /*count=*/column_count);
    299   EmitInnerLoopTiled(&lhs_memory_tile, rhs_tile,
    300                      /*columns=*/column_count, is_first_column);
    301   EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column);
    302 }
    303 
    304 void ColumnMajorMatrixVectorProductEmitter::Emit() {
    305   // See the comment on the class declaration for the algorithm used here.
    306   int64 column_remainder = k() % tile_cols();
    307   int64 column_limit = k() - column_remainder;
    308 
    309   ksl_.For("dot.outer.tiled",
    310            /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(),
    311            [&](llvm::Value* column, bool is_first_column) {
    312              EmitOuterLoopBody(column, tile_cols(), is_first_column);
    313            });
    314 
    315   if (column_remainder != 0) {
    316     EmitOuterLoopBody(b_->getInt64(column_limit), column_remainder,
    317                       column_limit == 0);
    318   }
    319 }
    320 
    321 void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
    322     MemoryTile* lhs_memory_tile, const std::vector<llvm::Value*>& rhs_tile,
    323     int64 columns, bool is_first_column) {
    324   int64 row_limit = m() - (m() % tile_rows());
    325 
    326   ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit,
    327            /*step=*/tile_rows(), [&](llvm::Value* row) {
    328              std::vector<llvm::Value*> lhs_tile =
    329                  lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row);
    330              llvm::Value* accumulator =
    331                  is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row)
    332                                             : vsl_.GetZeroVector())
    333                                  : vsl_.LoadVector(result_, row);
    334              for (int i = 0; i < columns; i++) {
    335                accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator);
    336              }
    337              vsl_.StoreVector(accumulator, result_, row);
    338            });
    339 }
    340 
    341 void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
    342     llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) {
    343   int64 row_start = m() - (m() % tile_rows());
    344   if (row_start == m()) {
    345     return;
    346   }
    347 
    348   llvm::Value* columns_llvm = b_->getInt64(columns);
    349 
    350   // for (col = current_tile_col; col < (columns + current_tile_col); col++)
    351   //   for (row = row_start, row < m_; row++) {
    352   //     result[row] += lhs[row, col] * rhs[col]
    353   //     // Also take into account that if col is 0 then result[row] is not
    354   //     // initialized.
    355   //   }
    356 
    357   ksl_.For(
    358       "dot.inner.epilg.outer", /*start=*/current_tile_col,
    359       /*end=*/b_->CreateAdd(columns_llvm, current_tile_col),
    360       /*step=*/1, /*peel_first_iteration=*/false,
    361       [&](llvm::Value* col, llvm::Value* is_first_scalar_col) {
    362         llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col);
    363         llvm::Value* total_offset = b_->CreateMul(col, b_->getInt64(m()));
    364         llvm::Value* lhs_base_pointer =
    365             vsl_.ComputeOffsetPointer(lhs_, total_offset);
    366         ksl_.For(
    367             "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(),
    368             /*step=*/1, [&](llvm::Value* scalar_row) {
    369               llvm::Value* product = vsl_.Mul(
    370                   vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element);
    371               llvm::Value* setting_result_first_time = b_->CreateAnd(
    372                   is_first_scalar_col, b_->getInt1(is_first_tiled_column));
    373               ksl_.If(
    374                   setting_result_first_time,
    375                   /*true_block_generator=*/
    376                   [&]() {
    377                     if (addend_) {
    378                       vsl_.StoreScalar(
    379                           vsl_.Add(vsl_.LoadScalar(addend_, scalar_row),
    380                                    product),
    381                           result_, scalar_row);
    382                     } else {
    383                       vsl_.StoreScalar(product, result_, scalar_row);
    384                     }
    385                   },
    386                   /*false_block_generator=*/
    387                   [&]() {
    388                     vsl_.StoreScalar(
    389                         vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product),
    390                         result_, scalar_row);
    391                   });
    392             });
    393       });
    394 }
    395 
    396 // Computes a dot product between "[M,K]{1,0} lhs" with a [K,1] vector (the
    397 // layout of the vector does not matter).  This implementation uses a tiling
    398 // scheme to improve performance.
    399 //
    400 // We logically separate the LHS matrix into four segments:
    401 //
    402 //   +----------------------+---+
    403 //   |                      |   |
    404 //   |                      |   |
    405 //   |         A            | B |
    406 //   |                      |   |
    407 //   |                      |   |
    408 //   |                      |   |
    409 //   +----------------------+---+
    410 //   |         C            | D |
    411 //   +----------------------+---+
    412 //
    413 // where A is the largest submatrix of the LHS that can be evenly divided into
    414 // tiles.  For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
    415 //
    416 //   +---+---+---+---+
    417 //   |M00|M10|M20|M30|
    418 //   +---+---+---+---+       +--+--+--+--+
    419 //   |M01|M11|M21|M31| and   |V0|V1|V2|V3|
    420 //   +---+---+---+---+       +--+--+--+--+
    421 //   |M02|M12|M22|M32|
    422 //   +---+---+---+---+
    423 //   |M03|M13|M23|M33|
    424 //   +---+---+---+---+
    425 //
    426 // (Legend: rows are horizontal and columns are vertical; and each row is one
    427 // llvm::Value of a vector type)
    428 //
    429 // where:
    430 //
    431 //   a. The left tile is loaded from the row major left matrix.
    432 //   b. The right vector is loaded from the RHS vector.
    433 //
    434 // We keep 4 vector accumulators accumulating the following four vector
    435 // expressions as we iterate over the row dimension:
    436 //
    437 //   +------+------+------+------+
    438 //   |M0I*V0|M1I*V1|M2I*V2|M3I*V3|  for I in [0,4)
    439 //   +------+------+------+------+
    440 //
    441 // In the end we do a horizontal reduction over these 4 vector accumulators to
    442 // get 4 values in the result vector.
    443 //
    444 // We have an inner epilogue loop to deal with the "B" sub-matrix and an outer
    445 // epilogue loop to deal with the C,D submatrix.
    446 class RowMajorMatrixVectorProductEmitter
    447     : public GemvConfig::User<RowMajorMatrixVectorProductEmitter> {
    448  public:
    449   class Config : public GemvConfig {
    450    public:
    451     explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols,
    452                     int64 m, int64 k, bool has_addend)
    453         : GemvConfig(/*name=*/"row_major_gemv", scalar_type,
    454                      /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m,
    455                      /*k=*/k, /*has_addend=*/has_addend) {}
    456   };
    457 
    458   RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs,
    459                                      llvm::Value* rhs, llvm::Value* addend,
    460                                      llvm::Value* result, llvm::IRBuilder<>* b)
    461       : config_(config),
    462         lhs_(lhs),
    463         rhs_(rhs),
    464         addend_(addend),
    465         result_(result),
    466         b_(b),
    467         ksl_(b_),
    468         vsl_(scalar_type(), /*vector_size=*/tile_cols(), b_, "") {
    469     CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast<uint64>(tile_cols())));
    470     CHECK(!has_addend() || addend != nullptr);
    471   }
    472 
    473   void Emit();
    474 
    475   const Config& config() const { return config_; }
    476 
    477  private:
    478   MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64 row_count) {
    479     return MemoryTile(&vsl_, b_, /*matrix=*/lhs_,
    480                       /*matrix_size_along_minor_dim=*/k(),
    481                       /*major_dim_offset=*/row_start,
    482                       /*tile_size_along_major_dim=*/row_count);
    483   }
    484 
    485   void EmitOuterLoopBody(llvm::Value* row, int64 row_count);
    486 
    487   void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, int64 rows,
    488                           std::vector<VectorVariable>* vector_accumulators);
    489 
    490   void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows,
    491                              std::vector<ScalarVariable>* scalar_accumulators);
    492 
    493   Config config_;
    494   llvm::Value* lhs_;
    495   llvm::Value* rhs_;
    496   llvm::Value* addend_;
    497   llvm::Value* result_;
    498   llvm::IRBuilder<>* b_;
    499   KernelSupportLibrary ksl_;
    500   VectorSupportLibrary vsl_;
    501 };
    502 
    503 void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row,
    504                                                            int64 row_count) {
    505   MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*row_start=*/row,
    506                                                 /*row_count=*/row_count);
    507   std::vector<VectorVariable> vector_accumulators;
    508   std::vector<ScalarVariable> scalar_accumulators;
    509   for (int i = 0; i < row_count; i++) {
    510     vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector());
    511     scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar());
    512   }
    513   EmitInnerLoopTiled(&lhs_memory_tile, /*rows=*/row_count,
    514                      &vector_accumulators);
    515   EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count,
    516                         &scalar_accumulators);
    517 
    518   std::vector<llvm::Value*> accumulator_values;
    519   std::transform(
    520       vector_accumulators.begin(), vector_accumulators.end(),
    521       std::back_inserter(accumulator_values),
    522       [](const VectorVariable& vector_var) { return vector_var.Get(); });
    523 
    524   std::vector<llvm::Value*> horizontal_sums;
    525   if (row_count == vsl_.vector_size()) {
    526     if (addend_) {
    527       horizontal_sums = vsl_.ComputeHorizontalSums(
    528           std::move(accumulator_values), vsl_.LoadVector(addend_, row));
    529     } else {
    530       horizontal_sums =
    531           vsl_.ComputeHorizontalSums(std::move(accumulator_values));
    532     }
    533   } else {
    534     horizontal_sums = vsl_.ComputeHorizontalSums(std::move(accumulator_values));
    535   }
    536 
    537   for (int i = 0; i < row_count; i++) {
    538     llvm::Value* result_value =
    539         vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get());
    540     llvm::Value* offset = b_->CreateAdd(b_->getInt64(i), row);
    541     if (addend_ && row_count != vsl_.vector_size()) {
    542       result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value);
    543     }
    544     vsl_.StoreScalar(result_value, result_, offset);
    545   }
    546 }
    547 
    548 void RowMajorMatrixVectorProductEmitter::Emit() {
    549   // See the comment on the class declaration for the algorithm used here.
    550   int64 row_remainder = m() % tile_rows();
    551   int64 row_limit = m() - row_remainder;
    552 
    553   ksl_.For("dot.outer.tiled",
    554            /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(),
    555            [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); });
    556 
    557   if (row_remainder != 0) {
    558     EmitOuterLoopBody(b_->getInt64(row_limit), row_remainder);
    559   }
    560 }
    561 
    562 void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
    563     MemoryTile* lhs_memory_tile, int64 rows,
    564     std::vector<VectorVariable>* vector_accumulators) {
    565   int64 column_limit = k() - (k() % tile_cols());
    566 
    567   ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit,
    568            /*step=*/tile_cols(), [&](llvm::Value* col) {
    569              std::vector<llvm::Value*> lhs_tile =
    570                  lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col);
    571              llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col);
    572              for (int i = 0; i < rows; i++) {
    573                llvm::Value* old_sum = (*vector_accumulators)[i].Get();
    574                (*vector_accumulators)[i].Set(
    575                    vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i])));
    576              }
    577            });
    578 }
    579 
    580 void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
    581     llvm::Value* current_tile_row, int64 rows,
    582     std::vector<ScalarVariable>* scalar_accumulators) {
    583   int64 column_start = k() - (k() % tile_cols());
    584   if (column_start == k()) {
    585     return;
    586   }
    587 
    588   for (int r = 0; r < rows; r++) {
    589     llvm::Value* total_offset = b_->CreateMul(
    590         b_->CreateAdd(b_->getInt64(r), current_tile_row), b_->getInt64(k()));
    591     llvm::Value* lhs_base_pointer =
    592         vsl_.ComputeOffsetPointer(lhs_, total_offset);
    593     ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(),
    594              /*step=*/1, [&](llvm::Value* scalar_col) {
    595                llvm::Value* product =
    596                    vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col),
    597                             vsl_.LoadScalar(rhs_, scalar_col));
    598                llvm::Value* old_value = (*scalar_accumulators)[r].Get();
    599                (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product));
    600              });
    601   }
    602 }
    603 
    604 // This class implements a tiled matrix multiplication algorithm, intended for
    605 // multiplying small matrices that don't need cache tiling.
    606 //
    607 // In the future this can be used as the innermost GEBP loop in a GEMM kernel as
    608 // described in Goto, Kazushige, and Robert A. van de Geijn. "Anatomy of
    609 // high-performance matrix multiplication." ACM Transactions on Mathematical
    610 // Software (TOMS) 34.3 (2008): 12.
    611 //
    612 // This only supports canonical dot operations (i.e. where the lhs contraction
    613 // dimension is 1 and the rhs contraction dimension is 0) over row major
    614 // matrices.
    615 class TiledSmallGemmEmitter {
    616  public:
    617   // Describe the dimensions of the kernel.
    618   class Dimensions {
    619    public:
    620     explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {}
    621 
    622     int64 m() const { return m_; }
    623     int64 k() const { return k_; }
    624     int64 n() const { return n_; }
    625 
    626     string ToString() const { return absl::StrCat(m(), "x", k(), "x", n()); }
    627 
    628    private:
    629     const int64 m_;
    630     const int64 k_;
    631     const int64 n_;
    632   };
    633 
    634   // Represents the configuration of the emitter.  The LLVM IR emitted by the
    635   // emitter, modulo the LLVM values holding the input and output buffers, must
    636   // be a function of the instance of `Config` passed to it.
    637   //
    638   // `dims` holds the matrix multiplication dimensions.
    639   //
    640   // `max_vectorization_width` is the maximum vector width (i.e. the width of
    641   // the largest vector register we will use).  This can be larger than the
    642   // largest vector register supported by the machine -- LLVM will legalize
    643   // these large vector widths into legally sized vectors.
    644   //
    645   // `max_vector_count` is the maximum number of vectors of size
    646   // `max_vectorization_width` that we will attempt to process at once.
    647   //
    648   // `min_vectorization_width` is the smallest vector width the emitter will use
    649   // -- below that it will devolve to using a scalar loop.
    650   //
    651   // The innermost reduction loop executes the matrix multiply in tiles of size
    652   // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`,
    653   // <vectorization width>] in the RHS.
    654   class Config {
    655    public:
    656     explicit Config(PrimitiveType scalar_type, Dimensions dims,
    657                     int64 max_vectorization_width, int64 max_vector_count,
    658                     int64 min_vectorization_width, int64 tile_size_m,
    659                     int64 tile_size_k)
    660         : scalar_type_(scalar_type),
    661           dims_(dims),
    662           max_vectorization_width_(max_vectorization_width),
    663           max_vector_count_(max_vector_count),
    664           min_vectorization_width_(min_vectorization_width),
    665           tile_size_m_(tile_size_m),
    666           tile_size_k_(tile_size_k) {}
    667 
    668     string GetCacheKey() const {
    669       return absl::StrCat("gemm_", PrimitiveType_Name(scalar_type()), "_",
    670                           dims().ToString(), "_", max_vectorization_width(),
    671                           "_", min_vectorization_width(), "_", tile_size_m(),
    672                           "_", tile_size_k());
    673     }
    674 
    675     PrimitiveType scalar_type() const { return scalar_type_; }
    676     Dimensions dims() const { return dims_; }
    677     int64 max_vectorization_width() const { return max_vectorization_width_; }
    678     int64 max_vector_count() const { return max_vector_count_; }
    679     int64 min_vectorization_width() const { return min_vectorization_width_; }
    680 
    681     int64 tile_size_m() const { return tile_size_m_; }
    682     int64 tile_size_k() const { return tile_size_k_; }
    683 
    684    private:
    685     PrimitiveType scalar_type_;
    686     Dimensions dims_;
    687     int64 max_vectorization_width_;
    688     int64 max_vector_count_;
    689     int64 min_vectorization_width_;
    690     int64 tile_size_m_;
    691     int64 tile_size_k_;
    692   };
    693 
    694   // Creates an instance of TiledSmallGemmEmitter that matrix-multiplies
    695   // `lhs` with `rhs` and stores the result in `result`.
    696   explicit TiledSmallGemmEmitter(Config config, llvm::Value* lhs,
    697                                  llvm::Value* rhs, llvm::Value* result,
    698                                  llvm::IRBuilder<>* b)
    699       : lhs_(lhs),
    700         rhs_(rhs),
    701         result_(result),
    702         config_(config),
    703         b_(b),
    704         ksl_(b_) {
    705     CHECK(max_vectorization_width() > 0 &&
    706           IsPowerOfTwo(static_cast<uint64>(max_vectorization_width())));
    707     CHECK_GT(max_vector_count(), 0);
    708     CHECK(min_vectorization_width() > 0 &&
    709           IsPowerOfTwo(static_cast<uint64>(min_vectorization_width())));
    710     CHECK_GE(max_vectorization_width(), min_vectorization_width());
    711     CHECK_GT(tile_size_k(), 0);
    712   }
    713 
    714   void Emit();
    715 
    716  private:
    717   // The HandleResiduesOnX helpers split the iteration space for dimension X
    718   // into a multiple of the tile size on dimension X and an epilogue.  These
    719   // helpers ultimately call into `EmitTiledGemm` for emitting the
    720   // tiled GEMM kernel.
    721 
    722   void HandleResiduesOnN();
    723   void HandleResiduesOnK(VectorSupportLibrary* vsl, llvm::Value* n_start,
    724                          llvm::Value* n_end);
    725   void HandleResiduesOnM(VectorSupportLibrary* vsl, int64 tile_size_k,
    726                          llvm::Value* k_start, llvm::Value* k_end,
    727                          llvm::Value* n_start, llvm::Value* n_end);
    728 
    729   // This emits a tiled GEMM kernel.  For a detailed description see the comment
    730   // on the implementation.
    731   void EmitTiledGemm(VectorSupportLibrary* vsl, int64 tile_size_k,
    732                      llvm::Value* k_start, llvm::Value* k_end,
    733                      llvm::Value* n_start, llvm::Value* n_end,
    734                      int64 tile_size_m, llvm::Value* m_start,
    735                      llvm::Value* m_end);
    736 
    737   llvm::Value* GetInt64(int64 value) { return b_->getInt64(value); }
    738 
    739   Config config() const { return config_; }
    740   Dimensions dims() const { return config().dims(); }
    741 
    742   int64 max_vectorization_width() const {
    743     return config().max_vectorization_width();
    744   }
    745   int64 max_vector_count() const { return config().max_vector_count(); }
    746   int64 min_vectorization_width() const {
    747     return config().min_vectorization_width();
    748   }
    749   int64 tile_size_m() const { return config().tile_size_m(); }
    750   int64 tile_size_k() const { return config().tile_size_k(); }
    751   PrimitiveType scalar_type() const { return config().scalar_type(); }
    752 
    753   llvm::Value* lhs_;
    754   llvm::Value* rhs_;
    755   llvm::Value* result_;
    756   Config config_;
    757 
    758   llvm::IRBuilder<>* b_;
    759   KernelSupportLibrary ksl_;
    760 };
    761 
    762 void TiledSmallGemmEmitter::Emit() { HandleResiduesOnN(); }
    763 
    764 void TiledSmallGemmEmitter::HandleResiduesOnN() {
    765   // We can only iterate the `n` dimension for an extent that is divisible by
    766   // the vectorization width.  So we emit an outer loop that first processes the
    767   // largest extent in `n` that is divisible by max_vectorization_width, then
    768   // the largest remaining extent that is divisible by max_vectorization_width /
    769   // 2 etc.
    770 
    771   int64 current_vectorization_width =
    772       max_vector_count() * max_vectorization_width();
    773   int64 current_vector_count = max_vector_count();
    774 
    775   int64 n_start = 0;
    776   while (n_start != dims().n() &&
    777          current_vectorization_width >= min_vectorization_width()) {
    778     int64 n_end = dims().n() - (dims().n() % current_vectorization_width);
    779     if (n_start != n_end) {
    780       VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_,
    781                                "gemm");
    782       HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end));
    783       n_start = n_end;
    784     }
    785     if (current_vector_count == 1) {
    786       current_vectorization_width /= 2;
    787     } else {
    788       current_vector_count--;
    789       current_vectorization_width =
    790           current_vector_count * max_vectorization_width();
    791     }
    792   }
    793 
    794   if (n_start != dims().n()) {
    795     VectorSupportLibrary vsl(scalar_type(), 1, b_, "gemm");
    796     ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) {
    797       llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1));
    798       HandleResiduesOnK(&vsl, n_i, n_i_next);
    799     });
    800   }
    801 }
    802 
    803 void TiledSmallGemmEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl,
    804                                               llvm::Value* n_start,
    805                                               llvm::Value* n_end) {
    806   int64 k_start = 0;
    807   int64 k_end = dims().k() - (dims().k() % tile_size_k());
    808   if (k_end != k_start) {
    809     HandleResiduesOnM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end),
    810                       n_start, n_end);
    811     k_start = k_end;
    812   }
    813 
    814   if (k_start != dims().k()) {
    815     HandleResiduesOnM(vsl, dims().k() - k_start, GetInt64(k_start),
    816                       GetInt64(dims().k()), n_start, n_end);
    817   }
    818 }
    819 
    820 void TiledSmallGemmEmitter::HandleResiduesOnM(
    821     VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
    822     llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) {
    823   const int64 m_end = dims().m() - dims().m() % tile_size_m();
    824   EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, tile_size_m(),
    825                 GetInt64(0), GetInt64(m_end));
    826 
    827   if (m_end != dims().m()) {
    828     EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end,
    829                   dims().m() - m_end, GetInt64(m_end), GetInt64(dims().m()));
    830   }
    831 }
    832 
// Emits a tiled GEMM over [m_start, m_end) x [n_start, n_end) x
// [k_start, k_end), adding the product into `result_`: each result tile is
// loaded before the K loop and stored back after it.  The M, N and K loops step
// by `tile_size_m`, the vector width of `vsl` and `tile_size_k` respectively
// and always operate on whole tiles; the HandleResiduesOn* callers are
// responsible for passing extents that are multiples of those steps.
//
// The loop structure is:
//
// Iterate over dimension M as m:
//   Iterate over dimension N as n:
//     Iterate over dimension K as k:
//       OutputTile[m,n] += Dot(LhsTile[m,k], RhsTile[k,n])
//
// I.e. just a tiled version of a "naive" GEMM.
//
// The tiling scheme is as follows:
//
// Let the LHS be:
//
//   +----+----+----+
//   | a0 | b0 | c0 | .
//   +----+----+----+ .
//   | a1 | b1 | c1 | .
//   +----+----+----+
//     ..     ..
//
// and the RHS be:
//
//   +----+----+----+----+
//   | p0 | p1 | p2 | p3 | .
//   +----+----+----+----+ .
//   | q0 | q1 | q2 | q3 | .
//   +----+----+----+----+
//   | r0 | r1 | r2 | r3 | .
//   +----+----+----+----+ .
//     ......    ......
//
// and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly denoted
// by `vsl`) be 4.  Then we want to matrix multiply this tile to get a [2,4]
// matrix that we can increment the result matrix by.
//
// First broadcast each element of the LHS tile into a vector of width 4, so
// each of the 2 rows becomes 3 vectors, giving us a rank 3 array, L, of
// dimension [2,3,4]:
//
//       L[0,_,_]           *      L[1,_,_]
//                          *
//   +----+----+----+----+  *  +----+----+----+----+
//   | a0 | a0 | a0 | a0 |  *  | a1 | a1 | a1 | a1 |
//   +----+----+----+----+  *  +----+----+----+----+
//   | b0 | b0 | b0 | b0 |  *  | b1 | b1 | b1 | b1 |
//   +----+----+----+----+  *  +----+----+----+----+
//   | c0 | c0 | c0 | c0 |  *  | c1 | c1 | c1 | c1 |
//   +----+----+----+----+  *  +----+----+----+----+
//
//
// Then we multiply-add L[0,_,_] with the RHS to get the first row of the result
// and L[1,_,_] with the RHS to get the second row of the result.  For example,
// the first row of the result is computed as:
//
//   +----+----+----+----+   +----+----+----+----+
//   | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 |   +
//   +----+----+----+----+   +----+----+----+----+
//
//   +----+----+----+----+   +----+----+----+----+
//   | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 |   +
//   +----+----+----+----+   +----+----+----+----+
//
//   +----+----+----+----+   +----+----+----+----+
//   | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 |
//   +----+----+----+----+   +----+----+----+----+
//
// to get:
//
//   +-------------------+-------------------+-------------------+---------
//   | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 |  ...
//   +-------------------+-------------------+-------------------+---------
void TiledSmallGemmEmitter::EmitTiledGemm(
    VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
    llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end,
    int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) {
  // One iteration per group of `tile_size_m` rows of the LHS and the result.
  ksl_.For("dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) {
    MemoryTile result_memory_tile(vsl, b_, /*matrix=*/result_,
                                  /*matrix_size_along_minor_dim=*/dims().n(),
                                  /*major_dim_offset=*/m_i,
                                  /*tile_size_along_major_dim=*/tile_size_m);
    MemoryTile lhs_memory_tile(vsl, b_, /*matrix=*/lhs_,
                               /*matrix_size_along_minor_dim=*/dims().k(),
                               /*major_dim_offset=*/m_i,
                               /*tile_size_along_major_dim=*/tile_size_m);
    // One iteration per vector-width-wide column strip of the RHS and result.
    ksl_.For(
        "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) {
          // The result tile is loaded once here and carried across the K loop
          // in `result_tile_var` (via Get/Set), then stored once afterwards.
          TileVariable result_tile_var(vsl, result_memory_tile.LoadTile(n_i));
          ksl_.For("dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) {
            // `tile_size_k` rows of the RHS starting at row k_i.
            MemoryTile rhs_memory_tile(vsl, b_, rhs_, dims().n(), k_i,
                                       tile_size_k);
            // lhs_tile[r][c] is LHS element (m_i + r, k_i + c) broadcast to a
            // full vector (the `L` array in the comment above).
            std::vector<std::vector<llvm::Value*>> lhs_tile =
                lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k);
            std::vector<llvm::Value*> rhs_tile = rhs_memory_tile.LoadTile(n_i);
            std::vector<llvm::Value*> result_tile = result_tile_var.Get();
            // These loops run at IR-emission time: they are fully unrolled into
            // tile_size_m * tile_size_k multiply-adds.
            for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) {
              for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) {
                result_tile[r_m_i] =
                    vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i],
                                result_tile[r_m_i]);
              }
            }
            result_tile_var.Set(result_tile);
          });

          result_memory_tile.StoreTile(result_tile_var.Get(), n_i);
        });
  });
}
    940 
    941 llvm::Type* GetPointerToElementType(llvm::Type* pointer_type) {
    942   llvm::Type* type =
    943       llvm::cast<llvm::PointerType>(pointer_type)->getElementType();
    944   while (auto* array_type = llvm::dyn_cast<llvm::ArrayType>(type)) {
    945     type = array_type->getElementType();
    946   }
    947 
    948   return type->getPointerTo();
    949 }
    950 
    951 struct GemvBuffersWithCanonicalType {
    952   llvm::Value* lhs_canonicalized;
    953   llvm::Value* rhs_canonicalized;
    954   llvm::Value* addend_canonicalized;
    955   llvm::Value* result_canonicalized;
    956 };
    957 
    958 GemvBuffersWithCanonicalType GetGemvBuffersWithCanonicalType(
    959     llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend,
    960     llvm::Value* result, llvm::IRBuilder<>* b) {
    961   // We characterize a GEMV operation via M and K, since N is implicitly 1.
    962   // This means the GEMV that multiplies (say) [5,6] with [6,1] is implemented
    963   // by the same GEMV that multiplies [5,6] with [1,6].  However, the
    964   // `llvm::Types` for the inputs to the two GEMVs don't match (in a trivial
    965   // sense -- the in memory representations are the same) since they're computed
    966   // from the `xla::Shape`s.  Since we want to be able to call the same
    967   // `llvm::Function` for the two GEMVs we canonicalize the types of the GEMV
    968   // inputs here into the same type.
    969   GemvBuffersWithCanonicalType buffers_with_canonical_type;
    970   llvm::Type* lhs_type = lhs->getType();
    971   llvm::Type* rhs_type = rhs->getType();
    972   llvm::Type* addend_type = addend ? addend->getType() : nullptr;
    973   llvm::Type* result_type = result->getType();
    974 
    975   buffers_with_canonical_type.lhs_canonicalized =
    976       b->CreateBitCast(lhs, GetPointerToElementType(lhs_type));
    977   buffers_with_canonical_type.rhs_canonicalized =
    978       b->CreateBitCast(rhs, GetPointerToElementType(rhs_type));
    979   buffers_with_canonical_type.addend_canonicalized =
    980       addend ? b->CreateBitCast(addend, GetPointerToElementType(addend_type))
    981              : nullptr;
    982   buffers_with_canonical_type.result_canonicalized =
    983       b->CreateBitCast(result, GetPointerToElementType(result_type));
    984 
    985   return buffers_with_canonical_type;
    986 }
    987 
    988 }  // namespace
    989 
    990 void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows,
    991                       int64 tile_cols, int64 m, int64 k, llvm::Value* lhs,
    992                       llvm::Value* rhs, llvm::Value* addend,
    993                       llvm::Value* result, llvm::IRBuilder<>* b,
    994                       const HloModuleConfig& module_config) {
    995   RowMajorMatrixVectorProductEmitter::Config config(
    996       /*scalar_type=*/scalar_type,
    997       /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols,
    998       /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr);
    999 
   1000   GemvBuffersWithCanonicalType canonical_inputs =
   1001       GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b);
   1002 
   1003   KernelSupportLibrary::EmitAndCallOutlinedKernel(
   1004       module_config, b, config.GetCacheKey(),
   1005       canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized,
   1006       canonical_inputs.addend_canonicalized,
   1007       canonical_inputs.result_canonicalized,
   1008       [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs,
   1009                                       llvm::Value* addend,
   1010                                       llvm::Value* result) {
   1011         RowMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend,
   1012                                                    result, b);
   1013         emitter.Emit();
   1014       });
   1015 }
   1016 
   1017 void EmitColumnMajorGemv(PrimitiveType scalar_type, int64 tile_rows,
   1018                          int64 tile_cols, int64 m, int64 k, llvm::Value* lhs,
   1019                          llvm::Value* rhs, llvm::Value* addend,
   1020                          llvm::Value* result, llvm::IRBuilder<>* b,
   1021                          const HloModuleConfig& module_config) {
   1022   ColumnMajorMatrixVectorProductEmitter::Config config(
   1023       /*scalar_type=*/scalar_type,
   1024       /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols,
   1025       /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr);
   1026 
   1027   GemvBuffersWithCanonicalType canonical_inputs =
   1028       GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b);
   1029 
   1030   KernelSupportLibrary::EmitAndCallOutlinedKernel(
   1031       module_config, b, config.GetCacheKey(),
   1032       canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized,
   1033       canonical_inputs.addend_canonicalized,
   1034       canonical_inputs.result_canonicalized,
   1035       [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs,
   1036                                       llvm::Value* addend,
   1037                                       llvm::Value* result) {
   1038         ColumnMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend,
   1039                                                       result, b);
   1040         emitter.Emit();
   1041       });
   1042 }
   1043 
   1044 void EmitSmallGemm(PrimitiveType scalar_type, int64 m, int64 k, int64 n,
   1045                    int64 max_vectorization_width, int64 max_vector_count,
   1046                    int64 min_vectorization_width, int64 tile_size_m,
   1047                    int64 tile_size_k, llvm::Value* lhs, llvm::Value* rhs,
   1048                    llvm::Value* result, llvm::IRBuilder<>* b,
   1049                    const HloModuleConfig& module_config) {
   1050   TiledSmallGemmEmitter::Config config(
   1051       /*scalar_type=*/scalar_type,
   1052       TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n},
   1053       /*max_vectorization_width=*/max_vectorization_width,
   1054       /*max_vector_count=*/max_vector_count,
   1055       /*min_vectorization_width=*/min_vectorization_width,
   1056       /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k);
   1057 
   1058   KernelSupportLibrary::EmitAndCallOutlinedKernel(
   1059       module_config, b, config.GetCacheKey(), lhs, rhs, result,
   1060       [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result) {
   1061         TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs,
   1062                                                  /*rhs=*/rhs,
   1063                                                  /*result=*/result, b);
   1064         small_gemm_emitter.Emit();
   1065       });
   1066 }
   1067 
   1068 }  // namespace cpu
   1069 }  // namespace xla
   1070