      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
     17 
     18 #include <memory>
     19 #include <vector>
     20 
     21 #include "llvm/IR/BasicBlock.h"
     22 #include "llvm/IR/Instructions.h"
     23 #include "llvm/IR/Module.h"
     24 #include "llvm/IR/Value.h"
     25 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
     26 #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
     27 #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
     28 #include "tensorflow/compiler/xla/service/hlo_module.h"
     29 #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
     30 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
     31 #include "tensorflow/compiler/xla/shape_util.h"
     32 #include "tensorflow/compiler/xla/status_macros.h"
     33 #include "tensorflow/compiler/xla/util.h"
     34 #include "tensorflow/compiler/xla/xla_data.pb.h"
     35 #include "tensorflow/core/platform/logging.h"
     36 
     37 namespace xla {
     38 
     39 using llvm_ir::SetToFirstInsertPoint;
     40 
     41 namespace cpu {
     42 
     43 namespace {
     44 // Loads a tile of values from a 2D tensor.
     45 class TileLoader {
     46  public:
     47   // Constructs a TileLoader that will load a tile consisting of
     48   // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at
     49   // `major_dim_offset` in the major dimension.  The tile size along the minor
     50   // dimension is the vector size, and that is implicitly determined by `vsl`.
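           //
           // For example (illustrative): for a column-major [m, k] matrix with
           // `matrix_size_along_minor_dim` == m and `major_dim_offset` == j, pointer i
           // (for i in [0, tile_size_along_major_dim)) points at element (0, j + i),
           // i.e. m * (j + i) elements past `matrix`.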
     51   TileLoader(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder,
     52              llvm::Value* matrix, int64 matrix_size_along_minor_dim,
     53              llvm::Value* major_dim_offset, int64 tile_size_along_major_dim)
     54       : vsl_(vsl) {
     55     pointers_.reserve(tile_size_along_major_dim);
     56     for (int64 i = 0; i < tile_size_along_major_dim; i++) {
     57       llvm::Value* total_offset = ir_builder->CreateMul(
     58           ir_builder->getInt64(matrix_size_along_minor_dim),
     59           ir_builder->CreateAdd(ir_builder->getInt64(i), major_dim_offset));
     60       pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset));
     61     }
     62   }
     63 
      64   // Loads a tile consisting of `tile_size_along_major_dim` vectors (fixed at
      65   // construction time) starting at `major_dim_offset` in the major dimension
      66   // and `minor_dim_offset` in the minor dimension.
     67   std::vector<llvm::Value*> LoadTile(llvm::Value* minor_dim_offset) const {
     68     std::vector<llvm::Value*> result;
     69     result.reserve(pointers_.size());
     70     for (const auto& pointer : pointers_) {
     71       result.push_back(vsl_->LoadVector(pointer, minor_dim_offset));
     72     }
     73     return result;
     74   }
     75 
     76  private:
     77   VectorSupportLibrary* vsl_;
     78   std::vector<llvm::Value*> pointers_;
     79 };
     80 
      81 // Computes a dot product between an "[M,K]{0,1} lhs" and a [K,1] vector (the
     82 // layout of the vector does not matter).  This implementation uses a tiling
     83 // scheme to improve performance.
     84 //
     85 // We logically separate the LHS matrix into four segments:
     86 //
     87 //   +----------------------+---+
     88 //   |                      |   |
     89 //   |                      |   |
     90 //   |         A            | B |
     91 //   |                      |   |
     92 //   |                      |   |
     93 //   |                      |   |
     94 //   +----------------------+---+
     95 //   |         C            | D |
     96 //   +----------------------+---+
     97 //
      98 // where A is the largest submatrix of the LHS that can be evenly divided into
     99 // tiles.  For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
    100 //
    101 //   +---+---+---+---+       +--+--+--+--+
    102 //   |M00|M10|M20|M30|       |V0|V1|V2|V3|
    103 //   +---+---+---+---+       +--+--+--+--+
    104 //   |M01|M11|M21|M31| and   |V0|V1|V2|V3|
    105 //   +---+---+---+---+       +--+--+--+--+
    106 //   |M02|M12|M22|M32|       |V0|V1|V2|V3|
    107 //   +---+---+---+---+       +--+--+--+--+
    108 //   |M03|M13|M23|M33|       |V0|V1|V2|V3|
    109 //   +---+---+---+---+       +--+--+--+--+
    110 //
    111 // (Legend: rows are horizontal and columns are vertical; and each column is one
    112 // llvm::Value of a vector type)
    113 //
    114 // where:
    115 //
    116 //   a. The left tile is from the column major left matrix.
    117 //   b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3]
    118 //      vector loaded from the RHS vector.
    119 //
    120 // As we iterate through the column dimension, we compute the change to the
    121 // result vector by an elementwise multiplication between the two tiles above
    122 // followed by a reduction along the major dimension:
    123 //
    124 //                     +-----------------------------------+
    125 //                     | M00*V0 + M10*V1 + M20*V2 + M30*V3 |
    126 //                     +-----------------------------------+
    127 //                     | M01*V0 + M11*V1 + M21*V2 + M31*V3 |
    128 // Result[R:R+4] +=    +-----------------------------------+
    129 //                     | M02*V0 + M12*V1 + M22*V2 + M32*V3 |
    130 //                     +-----------------------------------+
    131 //                     | M03*V0 + M13*V1 + M23*V2 + M33*V3 |
    132 //                     +-----------------------------------+
    133 //
    134 // Where R is the starting row for the tile.
    135 //
    136 // We have an inner epilogue loop to deal with the "C" submatrix and an outer
     137 // epilogue loop to deal with the B and D submatrices.
    138 //
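         // For a concrete (illustrative) example: with m_ = 9, k_ = 6 and
         // tile_rows_ = tile_cols_ = 4, the tiled loops cover A = rows [0, 8) x
         // columns [0, 4), the inner epilogue covers C = row 8 x columns [0, 4),
         // and the outer epilogue covers B and D = rows [0, 9) x columns [4, 6).
         //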
     139 // TODO(sanjoy): We should investigate whether gather loads and scatter stores
     140 // can be used here to share the same inner loop for both column-major and
     141 // row-major matrix-vector products.
    142 class ColumnMajorMatrixVectorProductEmitter {
    143  public:
    144   ColumnMajorMatrixVectorProductEmitter(PrimitiveType scalar_type,
    145                                         int64 tile_rows, int64 tile_cols,
    146                                         int64 m, int64 k, llvm::Value* lhs,
    147                                         llvm::Value* rhs, llvm::Value* addend,
    148                                         llvm::Value* result,
    149                                         llvm::IRBuilder<>* ir_builder)
    150       : scalar_type_(scalar_type),
    151         tile_rows_(tile_rows),
    152         tile_cols_(tile_cols),
    153         m_(m),
    154         k_(k),
    155         lhs_(lhs),
    156         rhs_(rhs),
    157         addend_(addend),
    158         result_(result),
    159         ir_builder_(ir_builder),
    160         ksl_(ir_builder_),
    161         vsl_(scalar_type_, /*vector_size=*/tile_rows_, ir_builder_, "") {
    162     CHECK(tile_rows_ > 0 && IsPowerOfTwo(static_cast<uint64>(tile_rows_)));
    163   }
    164 
    165   void Emit();
    166 
    167  private:
    168   void EmitOuterLoopBody(llvm::Value* column, int64 column_count,
    169                          bool is_first_column);
    170 
    171   TileLoader GetLhsTileLoader(llvm::Value* column_start, int64 column_count) {
    172     return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_,
    173                       /*matrix_size_along_minor_dim=*/m_,
    174                       /*major_dim_offset=*/column_start,
    175                       /*tile_size_along_major_dim=*/column_count);
    176   }
    177 
    178   // Load a tile of values from the RHS.  For the RHS a "tile" is a contiguous
    179   // sequence of `count` values, each one broadcasted to the vector width.
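           // For example (illustrative): with count == 4 this returns the four vectors
           // broadcast(rhs[offset + 0]), ..., broadcast(rhs[offset + 3]), each one
           // splatting a single RHS element across every vector lane.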
    180   std::vector<llvm::Value*> LoadRhsTile(llvm::Value* offset, int64 count) {
    181     llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset);
    182     std::vector<llvm::Value*> result;
    183     result.reserve(count);
    184     for (int64 i = 0; i < count; i++) {
    185       result.push_back(vsl_.LoadBroadcast(base_pointer, i));
    186     }
    187     return result;
    188   }
    189 
    190   void EmitInnerLoopTiled(TileLoader* lhs_tile_loader,
    191                           const std::vector<llvm::Value*>& rhs_tile,
    192                           int64 columns, bool is_first_column);
    193 
    194   void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns,
    195                              bool is_first_tiled_column);
    196 
    197   PrimitiveType scalar_type_;
    198   int64 tile_rows_;
    199   int64 tile_cols_;
    200   int64 m_;
    201   int64 k_;
    202   llvm::Value* lhs_;
    203   llvm::Value* rhs_;
    204   llvm::Value* addend_;
    205   llvm::Value* result_;
    206   llvm::IRBuilder<>* ir_builder_;
    207   KernelSupportLibrary ksl_;
    208   VectorSupportLibrary vsl_;
    209 };
    210 
    211 void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody(
    212     llvm::Value* column, int64 column_count, bool is_first_column) {
    213   TileLoader lhs_tile_loader = GetLhsTileLoader(/*column_start=*/column,
    214                                                 /*column_count=*/column_count);
    215 
    216   std::vector<llvm::Value*> rhs_tile =
    217       LoadRhsTile(column, /*count=*/column_count);
    218   EmitInnerLoopTiled(&lhs_tile_loader, rhs_tile,
    219                      /*columns=*/column_count, is_first_column);
    220   EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column);
    221 }
    222 
    223 void ColumnMajorMatrixVectorProductEmitter::Emit() {
    224   // See the comment on the class declaration for the algorithm used here.
    225   int64 column_remainder = k_ % tile_cols_;
    226   int64 column_limit = k_ - column_remainder;
    227 
    228   ksl_.For("dot.outer.tiled",
    229            /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols_,
    230            [&](llvm::Value* column, bool is_first_column) {
    231              EmitOuterLoopBody(column, tile_cols_, is_first_column);
    232            });
    233 
    234   if (column_remainder != 0) {
    235     EmitOuterLoopBody(ir_builder_->getInt64(column_limit), column_remainder,
    236                       column_limit == 0);
    237   }
    238 }
    239 
    240 void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
    241     TileLoader* lhs_tile_loader, const std::vector<llvm::Value*>& rhs_tile,
    242     int64 columns, bool is_first_column) {
    243   int64 row_limit = m_ - (m_ % tile_rows_);
    244 
    245   ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit,
    246            /*step=*/tile_rows_, [&](llvm::Value* row) {
    247              std::vector<llvm::Value*> lhs_tile =
    248                  lhs_tile_loader->LoadTile(/*minor_dim_offset=*/row);
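                      // On the first column tile the accumulator starts from the addend
                      // (or from zero when there is no addend); on later tiles it
                      // resumes from the partial sums already stored in result_.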
    249              llvm::Value* accumulator =
    250                  is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row)
    251                                             : vsl_.GetZeroVector())
    252                                  : vsl_.LoadVector(result_, row);
    253              for (int i = 0; i < columns; i++) {
    254                accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator);
    255              }
    256              vsl_.StoreVector(accumulator, result_, row);
    257            });
    258 }
    259 
    260 void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
    261     llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) {
    262   int64 row_start = m_ - (m_ % tile_rows_);
    263   if (row_start == m_) {
    264     return;
    265   }
    266 
    267   llvm::Value* columns_llvm = ir_builder_->getInt64(columns);
    268 
    269   // for (col = current_tile_col; col < (columns + current_tile_col); col++)
     270   //   for (row = row_start; row < m_; row++) {
    271   //     result[row] += lhs[row, col] * rhs[col]
    272   //     // Also take into account that if col is 0 then result[row] is not
    273   //     // initialized.
    274   //   }
    275 
    276   ksl_.For(
    277       "dot.inner.epilg.outer", /*start=*/current_tile_col,
    278       /*end=*/ir_builder_->CreateAdd(columns_llvm, current_tile_col),
    279       /*step=*/1, /*peel_first_iteration=*/false,
    280       [&](llvm::Value* col, llvm::Value* is_first_scalar_col) {
    281         llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col);
    282         llvm::Value* total_offset =
    283             ir_builder_->CreateMul(col, ir_builder_->getInt64(m_));
    284         llvm::Value* lhs_base_pointer =
    285             vsl_.ComputeOffsetPointer(lhs_, total_offset);
    286         ksl_.For(
    287             "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m_,
    288             /*step=*/1, [&](llvm::Value* scalar_row) {
    289               llvm::Value* product = vsl_.Mul(
    290                   vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element);
    291               llvm::Value* setting_result_first_time = ir_builder_->CreateAnd(
    292                   is_first_scalar_col,
    293                   ir_builder_->getInt1(is_first_tiled_column));
    294               ksl_.If(
    295                   setting_result_first_time,
    296                   /*true_block_generator=*/
    297                   [&]() {
    298                     if (addend_) {
    299                       vsl_.StoreScalar(
    300                           vsl_.Add(vsl_.LoadScalar(addend_, scalar_row),
    301                                    product),
    302                           result_, scalar_row);
    303                     } else {
    304                       vsl_.StoreScalar(product, result_, scalar_row);
    305                     }
    306                   },
    307                   /*false_block_generator=*/
    308                   [&]() {
    309                     vsl_.StoreScalar(
    310                         vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product),
    311                         result_, scalar_row);
    312                   });
    313             });
    314       });
    315 }
    316 
     317 // Computes a dot product between an "[M,K]{1,0} lhs" and a [K,1] vector (the
    318 // layout of the vector does not matter).  This implementation uses a tiling
    319 // scheme to improve performance.
    320 //
    321 // We logically separate the LHS matrix into four segments:
    322 //
    323 //   +----------------------+---+
    324 //   |                      |   |
    325 //   |                      |   |
    326 //   |         A            | B |
    327 //   |                      |   |
    328 //   |                      |   |
    329 //   |                      |   |
    330 //   +----------------------+---+
    331 //   |         C            | D |
    332 //   +----------------------+---+
    333 //
     334 // where A is the largest submatrix of the LHS that can be evenly divided into
    335 // tiles.  For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
    336 //
    337 //   +---+---+---+---+
    338 //   |M00|M10|M20|M30|
    339 //   +---+---+---+---+       +--+--+--+--+
    340 //   |M01|M11|M21|M31| and   |V0|V1|V2|V3|
    341 //   +---+---+---+---+       +--+--+--+--+
    342 //   |M02|M12|M22|M32|
    343 //   +---+---+---+---+
    344 //   |M03|M13|M23|M33|
    345 //   +---+---+---+---+
    346 //
    347 // (Legend: rows are horizontal and columns are vertical; and each row is one
    348 // llvm::Value of a vector type)
    349 //
    350 // where:
    351 //
    352 //   a. The left tile is loaded from the row major left matrix.
    353 //   b. The right vector is loaded from the RHS vector.
    354 //
    355 // We keep 4 vector accumulators accumulating the following four vector
    356 // expressions as we iterate over the row dimension:
    357 //
    358 //   +------+------+------+------+
    359 //   |M0I*V0|M1I*V1|M2I*V2|M3I*V3|  for I in [0,4)
    360 //   +------+------+------+------+
    361 //
    362 // In the end we do a horizontal reduction over these 4 vector accumulators to
    363 // get 4 values in the result vector.
    364 //
     365 // We have an inner epilogue loop to deal with the "B" submatrix and an outer
     366 // epilogue loop to deal with the C and D submatrices.
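         //
         // For a concrete (illustrative) example: with m_ = 9, k_ = 6 and
         // tile_rows_ = tile_cols_ = 4, the tiled loops cover A = rows [0, 8) x
         // columns [0, 4), the inner epilogue covers B = rows [0, 8) x columns [4, 6),
         // and the outer epilogue covers C and D = row 8 x columns [0, 6).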
    367 class RowMajorMatrixVectorProductEmitter {
    368  public:
    369   RowMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows,
    370                                      int64 tile_cols, int64 m, int64 k,
    371                                      llvm::Value* lhs, llvm::Value* rhs,
    372                                      llvm::Value* addend, llvm::Value* result,
    373                                      llvm::IRBuilder<>* ir_builder)
    374       : scalar_type_(scalar_type),
    375         tile_rows_(tile_rows),
    376         tile_cols_(tile_cols),
    377         m_(m),
    378         k_(k),
    379         lhs_(lhs),
    380         rhs_(rhs),
    381         addend_(addend),
    382         result_(result),
    383         ir_builder_(ir_builder),
    384         ksl_(ir_builder_),
    385         vsl_(scalar_type_, /*vector_size=*/tile_cols_, ir_builder_, "") {
    386     CHECK(tile_cols_ > 0 && IsPowerOfTwo(static_cast<uint64>(tile_cols_)));
    387   }
    388 
    389   void Emit();
    390 
    391  private:
    392   TileLoader GetLhsTileLoader(llvm::Value* row_start, int64 row_count) {
    393     return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_,
    394                       /*matrix_size_along_minor_dim=*/k_,
    395                       /*major_dim_offset=*/row_start,
    396                       /*tile_size_along_major_dim=*/row_count);
    397   }
    398 
    399   void EmitOuterLoopBody(llvm::Value* row, int64 row_count);
    400 
    401   void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, int64 rows,
    402                           std::vector<VectorVariable>* vector_accumulators);
    403 
    404   void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows,
    405                              std::vector<ScalarVariable>* scalar_accumulators);
    406 
    407   PrimitiveType scalar_type_;
    408   int64 tile_rows_;
    409   int64 tile_cols_;
    410   int64 m_;
    411   int64 k_;
    412   llvm::Value* lhs_;
    413   llvm::Value* rhs_;
    414   llvm::Value* addend_;
    415   llvm::Value* result_;
    416   llvm::IRBuilder<>* ir_builder_;
    417   KernelSupportLibrary ksl_;
    418   VectorSupportLibrary vsl_;
    419 };
    420 
    421 void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row,
    422                                                            int64 row_count) {
    423   TileLoader lhs_tile_loader = GetLhsTileLoader(/*row_start=*/row,
    424                                                 /*row_count=*/row_count);
    425   std::vector<VectorVariable> vector_accumulators;
    426   std::vector<ScalarVariable> scalar_accumulators;
    427   for (int i = 0; i < row_count; i++) {
    428     vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector());
    429     scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar());
    430   }
    431   EmitInnerLoopTiled(&lhs_tile_loader, /*rows=*/row_count,
    432                      &vector_accumulators);
    433   EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count,
    434                         &scalar_accumulators);
    435 
    436   std::vector<llvm::Value*> accumulator_values;
    437   std::transform(
    438       vector_accumulators.begin(), vector_accumulators.end(),
    439       std::back_inserter(accumulator_values),
    440       [](const VectorVariable& vector_var) { return vector_var.Get(); });
    441 
    442   std::vector<llvm::Value*> horizontal_sums;
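           // When the tile covers a full vector of output rows, the addend (if any)
           // can be loaded as a single vector and folded into the horizontal
           // reduction; otherwise it is added scalar-by-scalar in the loop below.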
    443   if (row_count == vsl_.vector_size()) {
    444     if (addend_) {
    445       horizontal_sums = vsl_.ComputeHorizontalSums(
    446           std::move(accumulator_values), vsl_.LoadVector(addend_, row));
    447     } else {
    448       horizontal_sums =
    449           vsl_.ComputeHorizontalSums(std::move(accumulator_values));
    450     }
    451   } else {
    452     horizontal_sums = vsl_.ComputeHorizontalSums(std::move(accumulator_values));
    453   }
    454 
    455   for (int i = 0; i < row_count; i++) {
    456     llvm::Value* result_value =
    457         vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get());
    458     llvm::Value* offset = ir_builder_->CreateAdd(ir_builder_->getInt64(i), row);
    459     if (addend_ && row_count != vsl_.vector_size()) {
    460       result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value);
    461     }
    462     vsl_.StoreScalar(result_value, result_, offset);
    463   }
    464 }
    465 
    466 void RowMajorMatrixVectorProductEmitter::Emit() {
    467   // See the comment on the class declaration for the algorithm used here.
    468   int64 row_remainder = m_ % tile_rows_;
    469   int64 row_limit = m_ - row_remainder;
    470 
    471   ksl_.For("dot.outer.tiled",
    472            /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows_,
    473            [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows_); });
    474 
    475   if (row_remainder != 0) {
    476     EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder);
    477   }
    478 }
    479 
    480 void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
    481     TileLoader* lhs_tile_loader, int64 rows,
    482     std::vector<VectorVariable>* vector_accumulators) {
    483   int64 column_limit = k_ - (k_ % tile_cols_);
    484 
    485   ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit,
    486            /*step=*/tile_cols_, [&](llvm::Value* col) {
    487              std::vector<llvm::Value*> lhs_tile =
    488                  lhs_tile_loader->LoadTile(/*minor_dim_offset=*/col);
    489              llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col);
    490              for (int i = 0; i < rows; i++) {
    491                llvm::Value* old_sum = (*vector_accumulators)[i].Get();
    492                (*vector_accumulators)[i].Set(
    493                    vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i])));
    494              }
    495            });
    496 }
    497 
    498 void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
    499     llvm::Value* current_tile_row, int64 rows,
    500     std::vector<ScalarVariable>* scalar_accumulators) {
    501   int64 column_start = k_ - (k_ % tile_cols_);
    502   if (column_start == k_) {
    503     return;
    504   }
    505 
    506   for (int r = 0; r < rows; r++) {
    507     llvm::Value* total_offset = ir_builder_->CreateMul(
    508         ir_builder_->CreateAdd(ir_builder_->getInt64(r), current_tile_row),
    509         ir_builder_->getInt64(k_));
    510     llvm::Value* lhs_base_pointer =
    511         vsl_.ComputeOffsetPointer(lhs_, total_offset);
    512     ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k_,
    513              /*step=*/1, [&](llvm::Value* scalar_col) {
    514                llvm::Value* product =
    515                    vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col),
    516                             vsl_.LoadScalar(rhs_, scalar_col));
    517                llvm::Value* old_value = (*scalar_accumulators)[r].Get();
    518                (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product));
    519              });
    520   }
    521 }
    522 
    523 }  // namespace
    524 
    525 DotOpEmitter::DotOpEmitter(
    526     const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs,
    527     const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
    528     const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array,
    529     llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder,
    530     const HloModuleConfig& hlo_module_config,
    531     const TargetMachineFeatures& target_machine_features)
    532     : dot_(dot),
    533       transpose_lhs_(transpose_lhs),
    534       transpose_rhs_(transpose_rhs),
    535       target_array_(target_array),
    536       lhs_array_(lhs_array),
    537       rhs_array_(rhs_array),
    538       addend_array_(addend_array),
    539       executable_run_options_value_(executable_run_options_value),
    540       ir_builder_(ir_builder),
    541       hlo_module_config_(hlo_module_config),
    542       target_machine_features_(target_machine_features) {}
    543 
    544 /* static */ tensorflow::Status DotOpEmitter::EmitDotOperation(
    545     const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs,
    546     const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
    547     const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array,
    548     llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder,
    549     const HloModuleConfig& hlo_module_config,
    550     const TargetMachineFeatures& target_machine_features) {
    551   PrimitiveType type = target_array.GetShape().element_type();
    552   TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type);
    553   DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array,
    554                            lhs_array, rhs_array, addend_array,
    555                            executable_run_options_value, ir_builder,
    556                            hlo_module_config, target_machine_features);
    557   return dot_emitter.Emit();
    558 }
    559 
    560 bool DotOpEmitter::ShapesAreLegalForRuntimeDot() const { return true; }
    561 
    562 bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
    563   if (dot_.shape().dimensions_size() != 2) {
    564     return false;
    565   }
    566 
    567   PrimitiveType primitive_type = dot_.shape().element_type();
    568 
    569   if (!primitive_util::IsFloatingPointType(primitive_type) &&
    570       !primitive_util::IsIntegralType(primitive_type)) {
    571     return false;
    572   }
    573 
    574   MatMultDims mat_mult_dims = GetMatMultDims();
    575   bool is_column_major_matrix_vector = false;
    576   bool is_row_major_matrix_vector = false;
    577 
    578   int64 m, k;
    579   bool swap_operands;
    580 
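           // When the output has a single row (m == 1) the dot is a vector-matrix
           // product; we emit it as a matrix-vector product with the operands swapped,
           // so the matrix comes from the original rhs.  When the output has a single
           // column (n == 1) it is already a matrix-vector product.  Which emitter we
           // use depends on the effective layout of the matrix operand.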
    581   if (mat_mult_dims.m == 1) {
    582     bool rhs_effectively_row_major =
    583         transpose_rhs_ ^ !mat_mult_dims.rhs_column_major;
    584     if (rhs_effectively_row_major) {
    585       k = mat_mult_dims.k;
    586       m = mat_mult_dims.n;
    587       is_column_major_matrix_vector = true;
    588       swap_operands = true;
    589     } else {
    590       k = mat_mult_dims.k;
    591       m = mat_mult_dims.n;
    592       is_row_major_matrix_vector = true;
    593       swap_operands = true;
    594     }
    595   }
    596 
    597   if (mat_mult_dims.n == 1) {
    598     bool lhs_effectively_column_major =
    599         transpose_lhs_ ^ mat_mult_dims.lhs_column_major;
    600     if (lhs_effectively_column_major) {
    601       m = mat_mult_dims.m;
    602       k = mat_mult_dims.k;
    603       is_column_major_matrix_vector = true;
    604       swap_operands = false;
    605     } else {
    606       m = mat_mult_dims.m;
    607       k = mat_mult_dims.k;
    608       is_row_major_matrix_vector = true;
    609       swap_operands = false;
    610     }
    611   }
    612 
    613   if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) {
    614     return false;
    615   }
    616 
    617   int64 tiling_factor = GetGemvTilingFactor();
    618   CHECK_GT(tiling_factor, 0);
    619 
    620   llvm::Value* result_op = target_array_.GetBasePointer();
    621   llvm::Value* lhs_op =
    622       swap_operands ? rhs_array_.GetBasePointer() : lhs_array_.GetBasePointer();
    623   llvm::Value* rhs_op =
    624       swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer();
    625 
    626   const bool enable_fast_math =
    627       hlo_module_config_.debug_options().xla_enable_fast_math();
    628   const bool optimize_for_size =
    629       options::OptimizeForSizeRequested(hlo_module_config_);
    630 
    631   const int target_vector_register_element_size =
    632       target_machine_features_.vector_register_num_elements(
    633           *ir_builder_->GetInsertBlock()->getParent(), primitive_type);
    634 
    635   // We may not always know the vector register size for the target we're
    636   // compiling against, in which case target_vector_register_element_size is 0.
    637   // In these cases we choose a default LLVM IR register size.
    638   const int kUnknownTargetVectorRegisterSize = 4;
    639   const int vector_register_element_size =
    640       target_vector_register_element_size == 0
    641           ? kUnknownTargetVectorRegisterSize
    642           : target_vector_register_element_size;
    643 
    644   if (is_column_major_matrix_vector) {
    645     VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
    646             << " and k = " << k;
    647     int64 tile_rows = vector_register_element_size;
    648     int64 tile_cols = tiling_factor;
    649 
    650     string kernel_name = tensorflow::strings::StrCat(
    651         "col_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows,
    652         "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : "");
    653 
    654     KernelSupportLibrary::EmitAndCallOutlinedKernel(
    655         /*enable_fast_math=*/enable_fast_math,
    656         /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name,
    657         lhs_op, rhs_op,
    658         addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op,
    659         [this, tile_rows, tile_cols, m, k, primitive_type](
    660             llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op,
    661             llvm::Value* result_op) {
    662           ColumnMajorMatrixVectorProductEmitter emitter(
    663               primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op,
    664               addend_op, result_op, ir_builder_);
    665           emitter.Emit();
    666         });
    667   } else {
    668     VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
    669             << " and k = " << k;
    670     int64 tile_rows = tiling_factor;
    671     int64 tile_cols = vector_register_element_size;
    672 
    673     string kernel_name = tensorflow::strings::StrCat(
    674         "row_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows,
    675         "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : "");
    676 
    677     KernelSupportLibrary::EmitAndCallOutlinedKernel(
    678         /*enable_fast_math=*/enable_fast_math,
    679         /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name,
    680         lhs_op, rhs_op,
    681         addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op,
    682         [this, tile_rows, tile_cols, m, k, primitive_type](
    683             llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op,
    684             llvm::Value* result_op) {
    685           RowMajorMatrixVectorProductEmitter emitter(
    686               primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op,
    687               addend_op, result_op, ir_builder_);
    688           emitter.Emit();
    689         });
    690   }
    691 
    692   return true;
    693 }
    694 
    695 tensorflow::Status DotOpEmitter::Emit() {
    696   // The dot operation performs a sum of products over dimension 0 of the left
    697   // hand side operand and dimension 1 of the right hand side operand.
    698   //
    699   // Let the shapes of lhs and rhs be defined as below:
    700   //
    701   //   lhs = [L{n-1} x L{n-2} x ... L{0}]
    702   //   rhs = [R{m-1} x R{m-2} x ... R{0}]
    703   //
    704   // The sum-of-products dimension in the lhs has size L{0} and the dimension in
    705   // the rhs has size R{1}. Necessarily, then:
    706   //
    707   //   L{0} == R{1}
    708   //
    709   // The output of the operation has the following shape:
    710   //
    711   //   output = [L{n-1} x L{n-2} x ... L{1} x R{m-1} x R{m-2} x ... R{2} x R{0}]
    712   //
    713   // To perform the operation we construct a loop nest with one for-loop for
    714   // each dimension of the output. Inside this loop nest is another for-loop
    715   // which performs the sum-of-products (the reduction loop) before storing
    716   // the result in the output buffer.
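           //
           // For example (illustrative): in the common case where lhs and rhs are both
           // matrices (n = m = 2 above), this reduces to lhs = [L1 x L0],
           // rhs = [R1 x R0], L0 == R1, and output = [L1 x R0], i.e. an ordinary
           // matrix product whose contracted dimension has size L0.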
    717 
    718   const Shape& lhs_shape = lhs_array_.GetShape();
    719   const Shape& rhs_shape = rhs_array_.GetShape();
    720 
    721   if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) {
    722     // If the operands are scalar, don't emit any loops.
    723     TF_RET_CHECK(ShapeUtil::IsScalar(lhs_shape) &&
    724                  ShapeUtil::IsScalar(rhs_shape));
    725     return EmitScalarDot();
    726   }
    727 
    728   if (EmitLlvmIrDotIfProfitable()) {
    729     return Status::OK();
    730   }
    731 
    732   CHECK_EQ(addend_array_, nullptr);
    733 
    734   if (PotentiallyImplementedAsEigenDot(dot_)) {
    735     return EmitCallToRuntime();
    736   }
    737 
    738   // Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a special
    739   // case where the reduction dimension is 0 for both LHS and RHS. This results
    740   // in a vector dot product producing a scalar.
    741   int64 lhs_reduction_dimension = 0;
    742   if (ShapeUtil::Rank(lhs_shape) >= 2) {
    743     lhs_reduction_dimension =
    744         ShapeUtil::GetDimensionNumber(lhs_shape, transpose_lhs_ ? -2 : -1);
    745   }
    746   int64 rhs_reduction_dimension = 0;
    747   if (ShapeUtil::Rank(rhs_shape) >= 2) {
    748     rhs_reduction_dimension =
    749         ShapeUtil::GetDimensionNumber(rhs_shape, transpose_rhs_ ? -1 : -2);
    750   }
    751 
     752   // Verify that the reduction dimensions in the two operands are the same size.
    753   TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) ==
    754                rhs_shape.dimensions(rhs_reduction_dimension));
    755 
    756   bool lhs_reduction_along_minor_dimension =
    757       lhs_reduction_dimension == LayoutUtil::Minor(lhs_shape.layout(), 0);
    758   bool rhs_reduction_along_minor_dimension =
    759       rhs_reduction_dimension == LayoutUtil::Minor(rhs_shape.layout(), 0);
    760 
    761   // Create loop nests which loop through the LHS operand dimensions and the RHS
    762   // operand dimensions. The reduction dimension of the LHS and RHS are handled
    763   // in a separate innermost loop which performs the sum of products.
    764   llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(&dot_), ir_builder_);
    765   llvm_ir::IrArray::Index lhs_index = EmitOperandArrayLoopNest(
    766       &loop_nest, lhs_array_, lhs_reduction_dimension, "lhs");
    767   llvm_ir::IrArray::Index rhs_index = EmitOperandArrayLoopNest(
    768       &loop_nest, rhs_array_, rhs_reduction_dimension, "rhs");
    769 
    770   // Create the loop which does the sum of products reduction.
    771   //
     772   // The prevent_unrolling bit works around a deficiency in LLVM's loop
     773   // vectorization pipeline, wherein unrolling a loop can, in some cases,
     774   // prevent effective vectorization.  Since we know that the IR we generate
     775   // when reducing across the minor dimension in both LHS and RHS is vectorized
     776   // well by the loop vectorizer, we block unrolling in that case to keep loop
     777   // unrolling from interfering with the vectorization.
    778   std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
    779       0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction",
    780       /*prevent_unrolling=*/lhs_reduction_along_minor_dimension &&
    781           rhs_reduction_along_minor_dimension);
    782 
     783   // The entry at the reduction dimension in the rhs and lhs indexes is the
     784   // indvar of the reduction loop.
    785   lhs_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue();
    786   rhs_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue();
    787 
    788   // For computing the sum of products we alloca a single location to store the
    789   // dot product result as we accumulate it within the reduction loop. After the
    790   // reduction loop we load the result and store into the output array.
    791 
    792   // Function entry basic block.
    793   // - Emit alloca for accumulator
    794   llvm::Function* func = reduction_loop->GetPreheaderBasicBlock()->getParent();
    795   SetToFirstInsertPoint(&func->getEntryBlock(), ir_builder_);
    796   llvm::Type* accum_type = target_array_.GetElementLlvmType();
    797   llvm::Value* accum_address = ir_builder_->CreateAlloca(
    798       accum_type, /*ArraySize=*/nullptr, "accum_address");
    799 
    800   // Preheader basic block of reduction loop:
    801   // - Initialize accumulator to zero.
    802   llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock();
    803   ir_builder_->SetInsertPoint(preheader_bb->getTerminator());
    804 
    805   ir_builder_->CreateStore(llvm::Constant::getNullValue(accum_type),
    806                            accum_address);
    807 
    808   // Body basic block of reduction loop:
    809   // - Load elements from lhs and rhs array.
    810   // - Multiply lhs-element and rhs-element.
    811   // - Load accumulator and add to product.
    812   // - Store sum back into accumulator.
    813   SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), ir_builder_);
    814 
    815   llvm::Value* lhs_element =
    816       lhs_array_.EmitReadArrayElement(lhs_index, ir_builder_);
    817   llvm::Value* rhs_element =
    818       rhs_array_.EmitReadArrayElement(rhs_index, ir_builder_);
    819 
    820   llvm::Value* accum = ir_builder_->CreateLoad(accum_address);
    821   llvm::Value* updated_accum;
    822   if (ShapeUtil::ElementIsComplex(lhs_shape)) {
    823     auto real = [&](llvm::Value* x) {
    824       return ir_builder_->CreateExtractValue(x, {0});
    825     };
    826     auto imag = [&](llvm::Value* x) {
    827       return ir_builder_->CreateExtractValue(x, {1});
    828     };
    829     llvm::Value* product_real = ir_builder_->CreateFSub(
    830         ir_builder_->CreateFMul(real(lhs_element), real(rhs_element)),
    831         ir_builder_->CreateFMul(imag(lhs_element), imag(rhs_element)));
    832     llvm::Value* product_imag = ir_builder_->CreateFAdd(
    833         ir_builder_->CreateFMul(real(lhs_element), imag(rhs_element)),
    834         ir_builder_->CreateFMul(imag(lhs_element), real(rhs_element)));
    835     updated_accum = ir_builder_->CreateInsertValue(
    836         accum, ir_builder_->CreateFAdd(real(accum), product_real), {0});
    837     updated_accum = ir_builder_->CreateInsertValue(
    838         updated_accum, ir_builder_->CreateFAdd(imag(accum), product_imag), {1});
    839   } else {
    840     llvm::Value* product = ir_builder_->CreateFMul(lhs_element, rhs_element);
    841     updated_accum = ir_builder_->CreateFAdd(accum, product);
    842   }
    843   ir_builder_->CreateStore(updated_accum, accum_address);
    844 
    845   // Exit basic block of reduction loop.
    846   // - Load accumulator value (the result).
    847   // - Store into output array.
    848   SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), ir_builder_);
    849 
    850   llvm::Value* result = ir_builder_->CreateLoad(accum_address);
    851 
     852   // Create index into target address. The target index is the concatenation of
     853   // the lhs and rhs indexes with the reduction dimensions removed. The terms
     854   // from the lhs index are the lower dimensions in the index so we add them
     855   // first.
    856   llvm_ir::IrArray::Index target_index;
    857   for (int dimension = 0; dimension < lhs_index.size(); ++dimension) {
    858     if (dimension != lhs_reduction_dimension) {
    859       target_index.push_back(lhs_index[dimension]);
    860     }
    861   }
    862   for (int dimension = 0; dimension < rhs_index.size(); ++dimension) {
    863     if (dimension != rhs_reduction_dimension) {
    864       target_index.push_back(rhs_index[dimension]);
    865     }
    866   }
    867 
    868   target_array_.EmitWriteArrayElement(target_index, result, ir_builder_);
    869 
     870   // Set the IR builder insert point to the exit basic block of the outermost
     871   // loop.
    872   ir_builder_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
    873 
    874   return tensorflow::Status::OK();
    875 }
    876 
    877 tensorflow::Status DotOpEmitter::EmitScalarDot() {
    878   // A scalar dot is just a scalar multiply.
    879   llvm::Value* result;
    880   llvm::Value* lhs_value =
    881       lhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_);
    882   llvm::Value* rhs_value =
    883       rhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_);
    884   if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) {
    885 #define REAL(x) ir_builder_->CreateExtractValue(x, {0})
    886 #define IMAG(x) ir_builder_->CreateExtractValue(x, {1})
    887     llvm::Value* real = ir_builder_->CreateFSub(
    888         ir_builder_->CreateFMul(REAL(lhs_value), REAL(rhs_value)),
    889         ir_builder_->CreateFMul(IMAG(lhs_value), IMAG(rhs_value)));
    890     llvm::Value* imag = ir_builder_->CreateFAdd(
    891         ir_builder_->CreateFMul(REAL(lhs_value), IMAG(rhs_value)),
    892         ir_builder_->CreateFMul(IMAG(lhs_value), REAL(rhs_value)));
    893 #undef IMAG
    894 #undef REAL
    895     result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType());
    896     result = ir_builder_->CreateInsertValue(result, real, {0});
    897     result = ir_builder_->CreateInsertValue(result, imag, {1});
    898   } else {
    899     result = ir_builder_->CreateFMul(lhs_value, rhs_value);
    900   }
    901   target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_);
    902   return tensorflow::Status::OK();
    903 }
    904 
    905 tensorflow::Status DotOpEmitter::EmitCallToRuntime() {
    906   DCHECK(ShapesAreLegalForRuntimeDot());
    907 
    908   // The signature of the Eigen runtime matmul function is:
    909   //
    910   //   (void)(void* run_options, float* out, float* lhs, float* rhs,
    911   //          int64 m, int64 n, int64 k, int32 transpose_lhs,
    912   //          int32 transpose_rhs);
    913   // The two transpose_... parameters are actually booleans, but we use int32
    914   // to avoid target-dependent calling convention details.
    915 
    916   bool multi_threaded_eigen =
    917       hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
    918   PrimitiveType type = target_array_.GetShape().element_type();
    919   llvm::Type* float_type;
    920   const char* fn_name;
    921   switch (type) {
    922     case F32:
    923       fn_name = multi_threaded_eigen
    924                     ? runtime::kEigenMatMulF32SymbolName
    925                     : runtime::kEigenSingleThreadedMatMulF32SymbolName;
    926       float_type = ir_builder_->getFloatTy();
    927       break;
    928     case F64:
    929       fn_name = multi_threaded_eigen
    930                     ? runtime::kEigenMatMulF64SymbolName
    931                     : runtime::kEigenSingleThreadedMatMulF64SymbolName;
    932       float_type = ir_builder_->getDoubleTy();
    933       break;
    934     default:
    935       return Unimplemented("Invalid type %s for dot operation",
    936                            PrimitiveType_Name(type).c_str());
    937   }
    938 
    939   llvm::Type* float_ptr_type = float_type->getPointerTo();
    940   llvm::Type* int64_type = ir_builder_->getInt64Ty();
    941   llvm::Type* int32_type = ir_builder_->getInt32Ty();
    942   llvm::Type* int8_ptr_type = ir_builder_->getInt8Ty()->getPointerTo();
    943   llvm::FunctionType* matmul_type = llvm::FunctionType::get(
    944       ir_builder_->getVoidTy(),
    945       {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type,
    946        int64_type, int64_type, int64_type, int32_type, int32_type},
    947       /*isVarArg=*/false);
    948 
    949   llvm::Function* function = ir_builder_->GetInsertBlock()->getParent();
    950   llvm::Module* module = function->getParent();
    951 
    952   llvm::Function* matmul_func = llvm::cast<llvm::Function>(
    953       module->getOrInsertFunction(fn_name, matmul_type));
    954   matmul_func->setCallingConv(llvm::CallingConv::C);
    955   matmul_func->setDoesNotThrow();
    956   matmul_func->setOnlyAccessesArgMemory();
    957 
    958   // The Eigen runtime function expects column-major layout. If the matrices are
    959   // row major, then use the following identity to compute the product:
    960   //
    961   //   (A x B)^T = B^T x A^T
    962   //
    963   // The connection between this identity and memory layout is that the
    964   // transpose operation can also be considered as an operation that changes the
    965   // memory layout of a matrix from row-major to column-major or vice versa.
    966   //
    967   // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'.
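           //
           // Concretely (illustrative): a row-major [m, n] output buffer holds the same
           // bytes as a column-major [n, m] buffer for output^T, so we ask Eigen for
           // output^T = rhs^T x lhs^T by swapping the operand pointers, the transpose
           // flags and 'm' with 'n'.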
    968 
    969   MatMultDims mat_mult_dims = GetMatMultDims();
    970 
    971   CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major);
    972 
    973   const llvm_ir::IrArray* lhs = &lhs_array_;
    974   const llvm_ir::IrArray* rhs = &rhs_array_;
    975   bool transpose_lhs = transpose_lhs_;
    976   bool transpose_rhs = transpose_rhs_;
    977 
    978   if (!mat_mult_dims.lhs_column_major) {
    979     std::swap(mat_mult_dims.m, mat_mult_dims.n);
    980     std::swap(lhs, rhs);
    981     std::swap(transpose_lhs, transpose_rhs);
    982   }
    983 
    984   ir_builder_->CreateCall(
    985       matmul_func,
    986       {ir_builder_->CreateBitCast(executable_run_options_value_, int8_ptr_type),
    987        ir_builder_->CreateBitCast(target_array_.GetBasePointer(),
    988                                   float_ptr_type),
    989        ir_builder_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type),
    990        ir_builder_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type),
    991        ir_builder_->getInt64(mat_mult_dims.m),
    992        ir_builder_->getInt64(mat_mult_dims.n),
    993        ir_builder_->getInt64(mat_mult_dims.k),
    994        ir_builder_->getInt32(transpose_lhs),
    995        ir_builder_->getInt32(transpose_rhs)});
    996   return tensorflow::Status::OK();
    997 }
    998 
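         // Note: the transpose flags are already folded in below, so `m` and `k` come
         // from the (possibly transposed) lhs shape and `n` from the rhs shape, while
         // `lhs_column_major` / `rhs_column_major` record whether dimension 0 is the
         // minor-most dimension of the respective operand's layout.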
    999 DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
   1000   CHECK_EQ(dot_.shape().dimensions_size(), 2);
   1001 
   1002   const Shape& lhs_shape = lhs_array_.GetShape();
   1003   const Shape& rhs_shape = rhs_array_.GetShape();
   1004 
   1005   return {lhs_shape.dimensions(transpose_lhs_ ? 1 : 0),
   1006           lhs_shape.dimensions(transpose_lhs_ ? 0 : 1),
   1007           rhs_shape.dimensions(transpose_rhs_ ? 0 : 1),
   1008           LayoutUtil::Minor(lhs_shape.layout(), 0) == 0,
   1009           LayoutUtil::Minor(rhs_shape.layout(), 0) == 0};
   1010 }
   1011 
   1012 llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest(
   1013     llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array,
   1014     int64 reduction_dimension, tensorflow::StringPiece name_suffix) {
   1015   // Prepares the dimension list we will use to emit the loop nest. Outermost
   1016   // loops are added first. Add loops in major-to-minor order, and skip the
   1017   // reduction dimension.
   1018   std::vector<int64> dimensions;
   1019   const Shape& shape = operand_array.GetShape();
   1020   for (int i = LayoutUtil::MinorToMajor(shape).size() - 1; i >= 0; --i) {
   1021     int64 dimension = LayoutUtil::Minor(shape.layout(), i);
   1022     if (dimension != reduction_dimension) {
   1023       dimensions.push_back(dimension);
   1024     }
   1025   }
   1026 
   1027   // Create loop nest with one for-loop for each dimension of the
   1028   // output.
   1029   llvm_ir::IrArray::Index index =
   1030       loop_nest->AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix);
   1031   // Verify every dimension except the reduction dimension was set in the index.
   1032   for (int dimension = 0; dimension < index.size(); ++dimension) {
   1033     if (dimension == reduction_dimension) {
   1034       DCHECK_EQ(nullptr, index[dimension]);
   1035     } else {
   1036       DCHECK_NE(nullptr, index[dimension]);
   1037     }
   1038   }
   1039   return index;
   1040 }
   1041 
   1042 // Return whether the given shape is a matrix with no padding.
   1043 static bool IsRank2WithNoPadding(const Shape& shape) {
   1044   return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape);
   1045 }
   1046 
   1047 // In a gemm operation where output = lhs * rhs, check whether the given shapes
   1048 // are valid for the operation.
   1049 static bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
   1050                                const Shape& output_shape) {
   1051   // The inputs and the output must
   1052   // 1) be matrices with no padding, and
   1053   // 2) have an allowed element type.
   1054   return output_shape.element_type() == F32 &&
   1055          IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) &&
   1056          IsRank2WithNoPadding(output_shape);
   1057 }
   1058 
   1059 bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
    1060   // For certain types of Dot, we can call Eigen.
   1061   if (hlo.opcode() == HloOpcode::kDot) {
   1062     const Shape& lhs_shape = hlo.operand(0)->shape();
   1063     const Shape& rhs_shape = hlo.operand(1)->shape();
   1064 
   1065     if (ShapeUtil::HasZeroElements(lhs_shape) ||
   1066         ShapeUtil::HasZeroElements(rhs_shape)) {
   1067       return false;
   1068     }
   1069 
   1070     if (ProfitableToImplementDotInTiledLlvmIr(hlo)) {
   1071       return false;
   1072     }
   1073 
   1074     // If gemm can accept the operand shapes, use it rather than a custom
   1075     // kernel.
   1076     if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) {
   1077       // The size of the reduction dimension should match. The shape inference
   1078       // guarantees this invariant, so the check here is for programming
   1079       // errors.
   1080       CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0));
   1081       return true;
   1082     }
   1083   }
   1084 
   1085   if (hlo.opcode() == HloOpcode::kFusion &&
   1086       hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot &&
   1087       hlo.fused_expression_root()->opcode() == HloOpcode::kDot) {
   1088     auto* dot = hlo.fused_expression_root();
   1089     const Shape& lhs_shape = dot->operand(0)->shape();
   1090     const Shape& rhs_shape = dot->operand(1)->shape();
   1091     if (ShapeUtil::HasZeroElements(lhs_shape) ||
   1092         ShapeUtil::HasZeroElements(rhs_shape)) {
   1093       return false;
   1094     }
   1095     return true;
   1096   }
   1097 
   1098   return false;
   1099 }
   1100 
    1101 // For vector-matrix dot products, it is always profitable to make the RHS
    1102 // column major.
   1103 tensorflow::gtl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
   1104     const HloInstruction& hlo) {
   1105   if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() == 2 &&
   1106       hlo.shape().dimensions(0) == 1) {
   1107     if (hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) == 0) {
   1108       return 1;
   1109     }
   1110     return {};
   1111   }
   1112 
   1113   if (hlo.opcode() == HloOpcode::kFusion &&
   1114       hlo.fusion_kind() == HloInstruction::FusionKind::kOutput) {
   1115     auto* fusion_root =
   1116         hlo.fused_instructions_computation()->root_instruction();
   1117     if (fusion_root->opcode() != HloOpcode::kAdd) {
   1118       return {};
   1119     }
   1120 
   1121     for (auto* fusion_root_op : fusion_root->operands()) {
   1122       if (fusion_root_op->opcode() != HloOpcode::kDot) {
   1123         continue;
   1124       }
   1125       if (auto operand_num =
   1126               ProfitableToMakeDotOperandColumnMajor(*fusion_root_op)) {
   1127         auto* operand = fusion_root_op->operand(*operand_num);
   1128         if (operand->opcode() == HloOpcode::kParameter &&
   1129             operand->user_count() == 1) {
   1130           return operand->parameter_number();
   1131         }
   1132       }
   1133     }
   1134   }
   1135 
   1136   return {};
   1137 }
   1138 
   1139 bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) {
   1140   // Any Matrix-Vector product of floating point or integral type, or
   1141   // a transpose-dot fusion of the same can be lowered to a tiled LLVM
   1142   // IR implementation.
   1143   const Shape& shape = dot.shape();
   1144   return shape.dimensions_size() == 2 &&
   1145          (shape.dimensions(0) == 1 || shape.dimensions(1) == 1) &&
   1146          (primitive_util::IsFloatingPointType(shape.element_type()) ||
   1147           primitive_util::IsIntegralType(shape.element_type()));
   1148 }
   1149 
   1150 }  // namespace cpu
   1151 }  // namespace xla
   1152