Home | History | Annotate | Download | only in cpu
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
     17 
     18 #include <memory>
     19 #include <vector>
     20 
     21 #include "absl/strings/str_cat.h"
     22 #include "llvm/IR/BasicBlock.h"
     23 #include "llvm/IR/Instructions.h"
     24 #include "llvm/IR/Module.h"
     25 #include "llvm/IR/Value.h"
     26 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
     27 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
     28 #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
     29 #include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"
     30 #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
     31 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
     32 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
     33 #include "tensorflow/compiler/xla/service/hlo_module.h"
     34 #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
     35 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
     36 #include "tensorflow/compiler/xla/shape_util.h"
     37 #include "tensorflow/compiler/xla/status_macros.h"
     38 #include "tensorflow/compiler/xla/util.h"
     39 #include "tensorflow/compiler/xla/xla_data.pb.h"
     40 #include "tensorflow/core/platform/logging.h"
     41 
     42 namespace xla {
     43 
     44 using llvm_ir::SetToFirstInsertPoint;
     45 
     46 namespace cpu {
     47 namespace {
     48 // Returns true if we should call into multi-threaded Eigen routines.
     49 bool ShouldUseMultiThreadedEigen(const HloModuleConfig& config) {
     50   return config.debug_options().xla_cpu_multi_thread_eigen();
     51 }
     52 
     53 // Represents a dot operation.  We use this in lieu of an `HloInstruction`
     54 // because we want to be able to create this for the "inner" dot operation in a
     55 // batch dot, for which there is no separate HLO instruction.
     56 struct DotInfo {
     57   Shape lhs_shape;
     58   Shape rhs_shape;
     59   Shape result_shape;
     60   DotDimensionNumbers dim_nums;
     61 
     62   DotInfo() = default;
     63 
     64   explicit DotInfo(const HloInstruction& instr) {
     65     CHECK_EQ(instr.opcode(), HloOpcode::kDot);
     66     lhs_shape = instr.operand(0)->shape();
     67     rhs_shape = instr.operand(1)->shape();
     68     result_shape = instr.shape();
     69     dim_nums = instr.dot_dimension_numbers();
     70   }
     71 };
     72 
// Dictates how a dot operation is implemented.  DotOpEmitter::Emit dispatches
// on this value to the matching Emit* lowering.
enum class DotImplementationStrategy {
  // The dot operation is lowered into LLVM IR that implements a naive nested
  // loop that computes the result one element at a time.  This is our
  // "fallback"; we don't really want this to kick in for any non-trivial dot
  // operation.
  kNaiveLlvmIr,

  // The dot operation is lowered into LLVM IR that implements a tiled
  // Matrix*Vector operation.  This strategy also allows fusing in a bias add
  // into the dot.  The matrix can be row major or column major, both are
  // supported.
  kTiledLlvmIrGemv,

  // The dot operation is lowered into LLVM IR that implements a tiled
  // Matrix*Matrix operation.  No fusions are supported.  The two inputs
  // and the output have to be row major.
  kTiledLlvmIrGemm,

  // The dot operation is lowered into a call into an Eigen routine.  No fusions
  // are supported today.  The two inputs and the output have to be row major.
  // However, we do allow transposing either the LHS or the RHS as part of the
  // GEMM -- we expose this flexibility as flexibility in the contraction
  // dimensions, but we can also see this as flexibility in the input layouts.
  kEigen,
};
     99 
// Returns the implementation strategy to use for the dot described by
// `dot_info`, given the module configuration `config` and the target
// description `target_machine_features`.  Only declared here; the selection
// logic lives in the definition.
DotImplementationStrategy GetDotImplementationStrategy(
    const HloModuleConfig& config, const DotInfo& dot_info,
    const TargetMachineFeatures& target_machine_features);
    105 
// Helper class for emitting LLVM IR to perform a single dot operation.
//
// The emitter stores references to the IrArrays, the HloModuleConfig and the
// TargetMachineFeatures passed to its constructor (see the data members
// below), so those objects must outlive it.
class DotOpEmitter {
 public:
  // `dot_info` describes the dot to lower; `dot_hlo_name` is used to name the
  // emitted loops.  `target_array`, `lhs_array` and `rhs_array` are the result
  // and operand buffers.  `addend_array` may be null; among the lowerings in
  // this class only the tiled GEMV one reads it, and the naive lowering
  // CHECK-fails if it is non-null.  `executable_run_options_value` is passed
  // as the first argument of the runtime matmul call.  IR is emitted via `b`.
  explicit DotOpEmitter(DotInfo dot_info, string dot_hlo_name,
                        const llvm_ir::IrArray& target_array,
                        const llvm_ir::IrArray& lhs_array,
                        const llvm_ir::IrArray& rhs_array,
                        const llvm_ir::IrArray* addend_array,
                        llvm::Value* executable_run_options_value,
                        llvm::IRBuilder<>* b,
                        const HloModuleConfig& hlo_module_config,
                        const TargetMachineFeatures& target_machine_features);

  // Emits the IR to perform the dot operation, choosing a lowering via
  // GetDotImplementationStrategy (scalar dots are handled separately).
  Status Emit();

 private:
  // Emits instructions to perform a scalar dot product (a multiply of the
  // LHS and RHS) and store the result in the target.
  Status EmitScalarDot();

  // Emits a call to the CPU runtime to perform the matrix multiply.
  Status EmitCallToRuntime();

  // Represents the dimensions of a matrix-matrix multiply operation.
  //
  // NOTE: GetMatMultDims builds this with positional aggregate
  // initialization, so the order of the fields below is significant.
  struct MatMultDims {
    // The number of rows in the LHS.
    int64 m;

    // The number of columns in the LHS, which must also equal the number of
    // rows in the RHS.
    int64 k;

    // The number of columns in the RHS.
    int64 n;

    // True if the LHS matrix is column major (dimension 0 is minor-most).
    bool lhs_column_major;

    // True if the LHS contraction dimension is not 1.
    bool lhs_non_canonical;

    // True if the RHS matrix is column major (dimension 0 is minor-most).
    bool rhs_column_major;

    // True if the RHS contraction dimension is not 0.
    bool rhs_non_canonical;

    // True if the result matrix is column major.
    bool target_column_major;
  };

  // Get the MatMultDims instance for the dot product this DotOpEmitter
  // represents.  Precondition: the dot is of rank 2 (and thus its operands are
  // of rank 2 as well).
  MatMultDims GetMatMultDims() const;

  // Lowers the dot operation as a tiled Matrix*Vector loop.
  void EmitTiledLlvmIrGemv();

  // Lowers the dot operation as a tiled Matrix*Matrix loop.
  void EmitTiledLlvmIrGemm();

  // Lowers the dot operation as a naive nested loop that computes the result
  // one element at a time.
  void EmitNaiveLlvmIrGemm();

  // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector
  // registers.  Defaults to 8 unless overridden through
  // options::LlvmIrGemvTilingFactor.
  int64 GetGemvTilingFactor() const {
    const int64 kDefaultTilingFactor = 8;
    return options::LlvmIrGemvTilingFactor(hlo_module_config_)
        .value_or(kDefaultTilingFactor);
  }

  // Returns the GEMM tile size as (tile_size_m, tile_size_k,
  // tile_size_n_in_vector_width), defaulting to (11, 9, 1) unless overridden
  // through options::LlvmIrGemmTileSize.
  std::tuple<int64, int64, int64> GetGemmTileSize() const {
    // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz
    //
    // TODO(b/80093688): Tune for other architectures and centralize this
    // information in one place.
    const std::tuple<int64, int64, int64> kDefaultTileSize =
        std::tuple<int64, int64, int64>(11, 9, 1);
    return options::LlvmIrGemmTileSize(hlo_module_config_)
        .value_or(kDefaultTileSize);
  }

  DotInfo dot_info_;
  string dot_hlo_name_;
  const llvm_ir::IrArray& target_array_;
  const llvm_ir::IrArray& lhs_array_;
  const llvm_ir::IrArray& rhs_array_;
  const llvm_ir::IrArray* addend_array_;
  llvm::Value* executable_run_options_value_;
  llvm::IRBuilder<>* b_;
  const HloModuleConfig& hlo_module_config_;
  const TargetMachineFeatures& target_machine_features_;
};
    203 }  // namespace
    204 
// Takes ownership of `dot_info` and `dot_hlo_name` (moved in); every other
// argument is stored as a reference or raw pointer and is not owned.
DotOpEmitter::DotOpEmitter(DotInfo dot_info, string dot_hlo_name,
                           const llvm_ir::IrArray& target_array,
                           const llvm_ir::IrArray& lhs_array,
                           const llvm_ir::IrArray& rhs_array,
                           const llvm_ir::IrArray* addend_array,
                           llvm::Value* executable_run_options_value,
                           llvm::IRBuilder<>* b,
                           const HloModuleConfig& hlo_module_config,
                           const TargetMachineFeatures& target_machine_features)
    : dot_info_(std::move(dot_info)),
      dot_hlo_name_(std::move(dot_hlo_name)),
      target_array_(target_array),
      lhs_array_(lhs_array),
      rhs_array_(rhs_array),
      addend_array_(addend_array),
      executable_run_options_value_(executable_run_options_value),
      b_(b),
      hlo_module_config_(hlo_module_config),
      target_machine_features_(target_machine_features) {}
    224 
    225 void DotOpEmitter::EmitTiledLlvmIrGemm() {
    226   PrimitiveType primitive_type = dot_info_.result_shape.element_type();
    227   MatMultDims mat_mult_dims = GetMatMultDims();
    228 
    229   llvm::Value* lhs = lhs_array_.GetBasePointer();
    230   llvm::Value* rhs = rhs_array_.GetBasePointer();
    231   llvm::Value* target = target_array_.GetBasePointer();
    232   int64 m = mat_mult_dims.m;
    233   int64 k = mat_mult_dims.k;
    234   int64 n = mat_mult_dims.n;
    235 
    236   if (mat_mult_dims.lhs_column_major) {
    237     std::swap(lhs, rhs);
    238     std::swap(m, n);
    239   }
    240 
    241   int64 size_bytes = m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
    242   b_->CreateMemSet(target, b_->getInt8(0), /*Size=*/size_bytes,
    243                    /*Align=*/1);
    244 
    245   int64 max_target_vector_width =
    246       target_machine_features_.vector_register_num_elements(
    247           *b_->GetInsertBlock()->getParent(), primitive_type);
    248 
    249   int64 tile_size_m, tile_size_k, tile_size_n_in_vector_width;
    250   std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) =
    251       GetGemmTileSize();
    252 
    253   EmitSmallGemm(
    254       /*scalar_type=*/primitive_type,
    255       /*m=*/m, /*k=*/k, /*n=*/n,
    256       /*max_vectorization_width=*/max_target_vector_width,
    257       /*max_vector_count=*/tile_size_n_in_vector_width,
    258       /*min_vectorization_width=*/std::min<int64>(4, max_target_vector_width),
    259       /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k, /*lhs=*/lhs,
    260       /*rhs=*/rhs, /*result=*/target, b_, hlo_module_config_);
    261 }
    262 
    263 void DotOpEmitter::EmitTiledLlvmIrGemv() {
    264   PrimitiveType primitive_type = dot_info_.result_shape.element_type();
    265 
    266   CHECK(primitive_util::IsFloatingPointType(primitive_type) ||
    267         primitive_util::IsIntegralType(primitive_type));
    268 
    269   MatMultDims mat_mult_dims = GetMatMultDims();
    270   bool is_column_major_matrix_vector = false;
    271   bool is_row_major_matrix_vector = false;
    272 
    273   int64 m, k;
    274   bool swap_operands;
    275 
    276   if (mat_mult_dims.m == 1) {
    277     bool rhs_effectively_row_major =
    278         mat_mult_dims.rhs_non_canonical ^ !mat_mult_dims.rhs_column_major;
    279     if (rhs_effectively_row_major) {
    280       k = mat_mult_dims.k;
    281       m = mat_mult_dims.n;
    282       is_column_major_matrix_vector = true;
    283       swap_operands = true;
    284     } else {
    285       k = mat_mult_dims.k;
    286       m = mat_mult_dims.n;
    287       is_row_major_matrix_vector = true;
    288       swap_operands = true;
    289     }
    290   }
    291 
    292   if (mat_mult_dims.n == 1) {
    293     bool lhs_effectively_column_major =
    294         mat_mult_dims.lhs_non_canonical ^ mat_mult_dims.lhs_column_major;
    295     if (lhs_effectively_column_major) {
    296       m = mat_mult_dims.m;
    297       k = mat_mult_dims.k;
    298       is_column_major_matrix_vector = true;
    299       swap_operands = false;
    300     } else {
    301       m = mat_mult_dims.m;
    302       k = mat_mult_dims.k;
    303       is_row_major_matrix_vector = true;
    304       swap_operands = false;
    305     }
    306   }
    307 
    308   CHECK(is_column_major_matrix_vector || is_row_major_matrix_vector);
    309 
    310   int64 tiling_factor = GetGemvTilingFactor();
    311   CHECK_GT(tiling_factor, 0);
    312 
    313   llvm::Value* result_op = target_array_.GetBasePointer();
    314   llvm::Value* lhs_op =
    315       swap_operands ? rhs_array_.GetBasePointer() : lhs_array_.GetBasePointer();
    316   llvm::Value* rhs_op =
    317       swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer();
    318 
    319   const int target_vector_register_element_size =
    320       target_machine_features_.vector_register_num_elements(
    321           *b_->GetInsertBlock()->getParent(), primitive_type);
    322 
    323   // We may not always know the vector register size for the target we're
    324   // compiling against, in which case target_vector_register_element_size is 0.
    325   // In these cases we choose a default LLVM IR register size.
    326   const int kUnknownTargetVectorRegisterSize = 4;
    327   const int vector_register_element_size =
    328       target_vector_register_element_size == 0
    329           ? kUnknownTargetVectorRegisterSize
    330           : target_vector_register_element_size;
    331 
    332   if (is_column_major_matrix_vector) {
    333     VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
    334             << " and k = " << k;
    335     EmitColumnMajorGemv(
    336         /*scalar_type=*/primitive_type,
    337         /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor,
    338         /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
    339         /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
    340         /*result=*/result_op, b_, hlo_module_config_);
    341   } else {
    342     VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
    343             << " and k = " << k;
    344     EmitRowMajorGemv(
    345         /*scalar_type=*/primitive_type,
    346         /*tile_rows=*/tiling_factor,
    347         /*tile_cols=*/vector_register_element_size,
    348         /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
    349         /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
    350         /*result=*/result_op, b_, hlo_module_config_);
    351   }
    352 }
    353 
    354 Status DotOpEmitter::Emit() {
    355   // The dot operation performs a sum of products over dimension 0 of the left
    356   // hand side operand and dimension 1 of the right hand side operand.
    357   //
    358   // Let the shapes of lhs and rhs be defined as below:
    359   //
    360   //   lhs = [L{n-1} x L{n-2} x ... L{0}]
    361   //   rhs = [R{m-1} x R{m-2} x ... R{0}]
    362   //
    363   // The sum-of-products dimension in the lhs has size L{0} and the dimension in
    364   // the rhs has size R{1}. Necessarily, then:
    365   //
    366   //   L{0} == R{1}
    367   //
    368   // The output of the operation has the following shape:
    369   //
    370   //   output = [L{n-1} x L{n-2} x ... L{1} x R{m-1} x R{m-2} x ... R{2} x R{0}]
    371   //
    372   // To perform the operation we construct a loop nest with one for-loop for
    373   // each dimension of the output. Inside this loop nest is another for-loop
    374   // which performs the sum-of-products (the reduction loop) before storing
    375   // the result in the output buffer.
    376 
    377   const Shape& lhs_shape = lhs_array_.GetShape();
    378   const Shape& rhs_shape = rhs_array_.GetShape();
    379 
    380   if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) {
    381     // If the operands are scalar, don't emit any loops.
    382     TF_RET_CHECK(ShapeUtil::IsScalar(lhs_shape) &&
    383                  ShapeUtil::IsScalar(rhs_shape));
    384     return EmitScalarDot();
    385   }
    386 
    387   switch (GetDotImplementationStrategy(hlo_module_config_, dot_info_,
    388                                        target_machine_features_)) {
    389     case DotImplementationStrategy::kNaiveLlvmIr:
    390       EmitNaiveLlvmIrGemm();
    391       return Status::OK();
    392 
    393     case DotImplementationStrategy::kTiledLlvmIrGemv:
    394       EmitTiledLlvmIrGemv();
    395       return Status::OK();
    396 
    397     case DotImplementationStrategy::kTiledLlvmIrGemm:
    398       EmitTiledLlvmIrGemm();
    399       return Status::OK();
    400 
    401     case DotImplementationStrategy::kEigen:
    402       return EmitCallToRuntime();
    403   }
    404 }
    405 
    406 void DotOpEmitter::EmitNaiveLlvmIrGemm() {
    407   CHECK_EQ(addend_array_, nullptr);
    408 
    409   const Shape& lhs_shape = lhs_array_.GetShape();
    410   const Shape& rhs_shape = rhs_array_.GetShape();
    411   const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;
    412 
    413   // Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a special
    414   // case where the reduction dimension is 0 for both LHS and RHS. This results
    415   // in a vector dot product producing a scalar.
    416   int64 lhs_reduction_dimension = dim_nums.lhs_contracting_dimensions(0);
    417   int64 rhs_reduction_dimension = dim_nums.rhs_contracting_dimensions(0);
    418 
    419   // Verify the reduction dimension in the two operands are the same size.
    420   CHECK_EQ(lhs_shape.dimensions(lhs_reduction_dimension),
    421            rhs_shape.dimensions(rhs_reduction_dimension));
    422 
    423   bool lhs_reduction_along_minor_dimension =
    424       lhs_reduction_dimension == LayoutUtil::Minor(lhs_shape.layout(), 0);
    425   bool rhs_reduction_along_minor_dimension =
    426       rhs_reduction_dimension == LayoutUtil::Minor(rhs_shape.layout(), 0);
    427 
    428   // Create loop nests which loop through the LHS operand dimensions and the RHS
    429   // operand dimensions. The reduction dimension of the LHS and RHS are handled
    430   // in a separate innermost loop which performs the sum of products.
    431   llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(dot_hlo_name_), b_);
    432   std::vector<llvm::Value*> lhs_multi_index =
    433       loop_nest.EmitOperandArrayLoopNest(
    434           lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs");
    435   std::vector<llvm::Value*> rhs_multi_index =
    436       loop_nest.EmitOperandArrayLoopNest(
    437           rhs_array_, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");
    438 
    439   // Create the loop which does the sum of products reduction.
    440   //
    441   // The prevent_unrolling bit is working around a deficiency in LLVM's loop
    442   // vectorization pipeline, wherein in some cases unrolling a loop can prevent
    443   // effective vectorization.  Since we know that the IR we generate when
    444   // reducing across the minor dimension in both LHS and RHS is vectorized well
    445   // by the loop vectorizer, we block unrolling in that case to stop loop unroll
    446   // from messing up the vectorization.
    447   std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
    448       0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction",
    449       /*unroll_mode=*/
    450       (lhs_reduction_along_minor_dimension &&
    451        rhs_reduction_along_minor_dimension)
    452           ? xla::llvm_ir::UnrollMode::kNoUnroll
    453           : xla::llvm_ir::UnrollMode::kDefaultUnroll);
    454 
    455   // The final entry in the rhs and lhs indexes is the indvar of the
    456   // reduction loop.
    457   lhs_multi_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue();
    458   llvm_ir::IrArray::Index lhs_index(lhs_multi_index, lhs_shape,
    459                                     b_->getInt64Ty());
    460   rhs_multi_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue();
    461   llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_shape,
    462                                     b_->getInt64Ty());
    463 
    464   // For computing the sum of products we alloca a single location to store the
    465   // dot product result as we accumulate it within the reduction loop. After the
    466   // reduction loop we load the result and store into the output array.
    467 
    468   // Function entry basic block.
    469   // - Emit alloca for accumulator
    470   llvm::Function* func = reduction_loop->GetPreheaderBasicBlock()->getParent();
    471   SetToFirstInsertPoint(&func->getEntryBlock(), b_);
    472   llvm::Type* accum_type = target_array_.GetElementLlvmType();
    473   llvm::Value* accum_address =
    474       b_->CreateAlloca(accum_type, /*ArraySize=*/nullptr, "accum_address");
    475 
    476   // Preheader basic block of reduction loop:
    477   // - Initialize accumulator to zero.
    478   llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock();
    479   b_->SetInsertPoint(preheader_bb->getTerminator());
    480 
    481   b_->CreateStore(llvm::Constant::getNullValue(accum_type), accum_address);
    482 
    483   // Body basic block of reduction loop:
    484   // - Load elements from lhs and rhs array.
    485   // - Multiply lhs-element and rhs-element.
    486   // - Load accumulator and add to product.
    487   // - Store sum back into accumulator.
    488   SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), b_);
    489 
    490   llvm::Value* lhs_element = lhs_array_.EmitReadArrayElement(lhs_index, b_);
    491   llvm::Value* rhs_element = rhs_array_.EmitReadArrayElement(rhs_index, b_);
    492 
    493   llvm::Value* accum = b_->CreateLoad(accum_address);
    494   llvm::Value* updated_accum;
    495   if (ShapeUtil::ElementIsComplex(lhs_shape)) {
    496     auto real = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {0}); };
    497     auto imag = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {1}); };
    498     llvm::Value* product_real =
    499         b_->CreateFSub(b_->CreateFMul(real(lhs_element), real(rhs_element)),
    500                        b_->CreateFMul(imag(lhs_element), imag(rhs_element)));
    501     llvm::Value* product_imag =
    502         b_->CreateFAdd(b_->CreateFMul(real(lhs_element), imag(rhs_element)),
    503                        b_->CreateFMul(imag(lhs_element), real(rhs_element)));
    504     updated_accum = b_->CreateInsertValue(
    505         accum, b_->CreateFAdd(real(accum), product_real), {0});
    506     updated_accum = b_->CreateInsertValue(
    507         updated_accum, b_->CreateFAdd(imag(accum), product_imag), {1});
    508   } else {
    509     llvm::Value* product = b_->CreateFMul(lhs_element, rhs_element);
    510     updated_accum = b_->CreateFAdd(accum, product);
    511   }
    512   b_->CreateStore(updated_accum, accum_address);
    513 
    514   // Exit basic block of reduction loop.
    515   // - Load accumulator value (the result).
    516   // - Store into output array.
    517   SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), b_);
    518 
    519   llvm::Value* result = b_->CreateLoad(accum_address);
    520 
    521   // Create index into target address. The target index is the concatenation of
    522   // the rhs and lhs indexes with the reduction dimensions removed. The terms
    523   // from the rhs index are the lower dimensions in the index so we add them
    524   // first.
    525   std::vector<llvm::Value*> target_multi_index;
    526   for (int dimension = 0; dimension < lhs_index.size(); ++dimension) {
    527     if (dimension != lhs_reduction_dimension) {
    528       target_multi_index.push_back(lhs_index[dimension]);
    529     }
    530   }
    531   for (int dimension = 0; dimension < rhs_index.size(); ++dimension) {
    532     if (dimension != rhs_reduction_dimension) {
    533       target_multi_index.push_back(rhs_index[dimension]);
    534     }
    535   }
    536 
    537   llvm_ir::IrArray::Index target_index(
    538       target_multi_index, target_array_.GetShape(), lhs_index.GetType());
    539   target_array_.EmitWriteArrayElement(target_index, result, b_);
    540 
    541   // Set the IR builder insert point to the exit basic block of the outer most
    542   // loop.
    543   b_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
    544 }
    545 
    546 Status DotOpEmitter::EmitScalarDot() {
    547   // A scalar dot is just a scalar multiply.
    548   llvm::Value* result;
    549   // Use the same index_type for all tensor accesses in the same kernel.
    550   llvm::Type* index_type = b_->getInt64Ty();
    551   llvm_ir::IrArray::Index element_index(index_type);
    552   llvm::Value* lhs_value =
    553       lhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
    554   llvm::Value* rhs_value =
    555       rhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
    556   if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) {
    557     auto get_real = [&](llvm::Value* x) {
    558       return b_->CreateExtractValue(x, {0});
    559     };
    560 
    561     auto get_imag = [&](llvm::Value* x) {
    562       return b_->CreateExtractValue(x, {1});
    563     };
    564 
    565     llvm::Value* real = b_->CreateFSub(
    566         b_->CreateFMul(get_real(lhs_value), get_real(rhs_value)),
    567         b_->CreateFMul(get_imag(lhs_value), get_imag(rhs_value)));
    568     llvm::Value* imag = b_->CreateFAdd(
    569         b_->CreateFMul(get_real(lhs_value), get_imag(rhs_value)),
    570         b_->CreateFMul(get_imag(lhs_value), get_real(rhs_value)));
    571     result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType());
    572     result = b_->CreateInsertValue(result, real, {0});
    573     result = b_->CreateInsertValue(result, imag, {1});
    574   } else {
    575     result = b_->CreateFMul(lhs_value, rhs_value);
    576   }
    577   target_array_.EmitWriteArrayElement(/*index=*/element_index, result, b_);
    578   return Status::OK();
    579 }
    580 
    581 Status DotOpEmitter::EmitCallToRuntime() {
    582   // The signature of the Eigen runtime matmul function is:
    583   //
    584   //   (void)(void* run_options, float* out, float* lhs, float* rhs,
    585   //          int64 m, int64 n, int64 k, int32 transpose_lhs,
    586   //          int32 transpose_rhs);
    587   // The two transpose_... parameters are actually booleans, but we use int32
    588   // to avoid target-dependent calling convention details.
    589 
    590   bool multi_threaded = ShouldUseMultiThreadedEigen(hlo_module_config_);
    591   bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
    592   PrimitiveType type = target_array_.GetShape().element_type();
    593   llvm::Type* float_type;
    594   const char* fn_name;
    595   switch (type) {
    596     case F16:
    597       fn_name = multi_threaded
    598                     ? runtime::kEigenMatMulF16SymbolName
    599                     : runtime::kEigenSingleThreadedMatMulF16SymbolName;
    600       float_type = b_->getHalfTy();
    601       break;
    602     case F32:
    603       fn_name = multi_threaded
    604                     ? (use_mkl_dnn ? runtime::kMKLMatMulF32SymbolName
    605                                    : runtime::kEigenMatMulF32SymbolName)
    606                     : (use_mkl_dnn
    607                            ? runtime::kMKLSingleThreadedMatMulF32SymbolName
    608                            : runtime::kEigenSingleThreadedMatMulF32SymbolName);
    609       float_type = b_->getFloatTy();
    610       break;
    611     case F64:
    612       fn_name = multi_threaded
    613                     ? (use_mkl_dnn ? runtime::kMKLMatMulF64SymbolName
    614                                    : runtime::kEigenMatMulF64SymbolName)
    615                     : (use_mkl_dnn
    616                            ? runtime::kMKLSingleThreadedMatMulF64SymbolName
    617                            : runtime::kEigenSingleThreadedMatMulF64SymbolName);
    618       float_type = b_->getDoubleTy();
    619       break;
    620     default:
    621       return Unimplemented("Invalid type %s for dot operation",
    622                            PrimitiveType_Name(type));
    623   }
    624 
    625   llvm::Type* float_ptr_type = float_type->getPointerTo();
    626   llvm::Type* int64_type = b_->getInt64Ty();
    627   llvm::Type* int32_type = b_->getInt32Ty();
    628   llvm::Type* int8_ptr_type = b_->getInt8Ty()->getPointerTo();
    629   llvm::FunctionType* matmul_type = llvm::FunctionType::get(
    630       b_->getVoidTy(),
    631       {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type,
    632        int64_type, int64_type, int64_type, int32_type, int32_type},
    633       /*isVarArg=*/false);
    634 
    635   llvm::Function* function = b_->GetInsertBlock()->getParent();
    636   llvm::Module* module = function->getParent();
    637 
    638   llvm::FunctionCallee matmul_func =
    639       module->getOrInsertFunction(fn_name, matmul_type);
    640   if (auto* fn = llvm::dyn_cast<llvm::Function>(matmul_func.getCallee())) {
    641     fn->setCallingConv(llvm::CallingConv::C);
    642     fn->setDoesNotThrow();
    643     fn->setOnlyAccessesArgMemory();
    644   }
    645 
    646   // The Eigen runtime function expects column-major layout. If the matrices are
    647   // row major, then use the following identity to compute the product:
    648   //
    649   //   (A x B)^T = B^T x A^T
    650   //
    651   // The connection between this identity and memory layout is that the
    652   // transpose operation can also be considered as an operation that changes the
    653   // memory layout of a matrix from row-major to column-major or vice versa.
    654   //
    655   // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'.
    656 
    657   MatMultDims mat_mult_dims = GetMatMultDims();
    658 
    659   CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major);
    660 
    661   const llvm_ir::IrArray* lhs = &lhs_array_;
    662   const llvm_ir::IrArray* rhs = &rhs_array_;
    663   bool transpose_lhs = mat_mult_dims.lhs_non_canonical;
    664   bool transpose_rhs = mat_mult_dims.rhs_non_canonical;
    665 
    666   if (!mat_mult_dims.lhs_column_major) {
    667     std::swap(mat_mult_dims.m, mat_mult_dims.n);
    668     std::swap(lhs, rhs);
    669     std::swap(transpose_lhs, transpose_rhs);
    670   }
    671 
    672   b_->CreateCall(
    673       matmul_func,
    674       {b_->CreateBitCast(executable_run_options_value_, int8_ptr_type),
    675        b_->CreateBitCast(target_array_.GetBasePointer(), float_ptr_type),
    676        b_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type),
    677        b_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type),
    678        b_->getInt64(mat_mult_dims.m), b_->getInt64(mat_mult_dims.n),
    679        b_->getInt64(mat_mult_dims.k), b_->getInt32(transpose_lhs),
    680        b_->getInt32(transpose_rhs)});
    681   return Status::OK();
    682 }
    683 
    684 DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
    685   CHECK_EQ(dot_info_.result_shape.dimensions_size(), 2);
    686 
    687   const Shape& lhs_shape = lhs_array_.GetShape();
    688   const Shape& rhs_shape = rhs_array_.GetShape();
    689   const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;
    690 
    691   return {
    692       /*m=*/lhs_shape.dimensions(1 - dim_nums.lhs_contracting_dimensions(0)),
    693       /*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)),
    694       /*n=*/rhs_shape.dimensions(1 - dim_nums.rhs_contracting_dimensions(0)),
    695       /*lhs_column_major=*/LayoutUtil::Minor(lhs_shape.layout(), 0) == 0,
    696       /*lhs_non_canonical=*/dim_nums.lhs_contracting_dimensions(0) == 0,
    697       /*rhs_column_major=*/LayoutUtil::Minor(rhs_shape.layout(), 0) == 0,
    698       /*rhs_non_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 1,
    699       /*target_column_major=*/
    700       LayoutUtil::Minor(target_array_.GetShape().layout(), 0) == 0};
    701 }
    702 
    703 // For vector-matrix dot products, it is always profitable to make the Rhs
    704 // column major.
    705 absl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
    706     const HloInstruction& hlo) {
    707   if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() == 2 &&
    708       hlo.shape().dimensions(0) == 1) {
    709     if (hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) == 0) {
    710       return 1;
    711     }
    712     return {};
    713   }
    714 
    715   if (hlo.opcode() == HloOpcode::kFusion &&
    716       hlo.fusion_kind() == HloInstruction::FusionKind::kOutput) {
    717     auto* fusion_root =
    718         hlo.fused_instructions_computation()->root_instruction();
    719     if (fusion_root->opcode() != HloOpcode::kAdd) {
    720       return {};
    721     }
    722 
    723     for (auto* fusion_root_op : fusion_root->operands()) {
    724       if (fusion_root_op->opcode() != HloOpcode::kDot) {
    725         continue;
    726       }
    727       if (auto operand_num =
    728               ProfitableToMakeDotOperandColumnMajor(*fusion_root_op)) {
    729         auto* operand = fusion_root_op->operand(*operand_num);
    730         if (operand->opcode() == HloOpcode::kParameter &&
    731             operand->user_count() == 1) {
    732           return operand->parameter_number();
    733         }
    734       }
    735     }
    736   }
    737 
    738   return {};
    739 }
    740 
    741 namespace {
    742 // Return whether the given shape is rank 2.
    743 bool IsRank2(const Shape& shape) { return shape.rank() == 2; }
    744 
    745 bool IsSimpleLayout(const Layout& layout) {
    746   return layout.tiles().empty() && layout.format() == DENSE;
    747 }
    748 
    749 // In a gemm operation where output = lhs * rhs, check whether the given shapes
    750 // are valid for the operation.
    751 bool AreGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
    752                    const Shape& output_shape,
    753                    const TargetMachineFeatures& target_machine_features) {
    754   CHECK(!lhs_shape.has_layout() || IsSimpleLayout(lhs_shape.layout()))
    755       << lhs_shape.DebugString();
    756   CHECK(!rhs_shape.has_layout() || IsSimpleLayout(rhs_shape.layout()))
    757       << rhs_shape.DebugString();
    758   CHECK(!output_shape.has_layout() || IsSimpleLayout(output_shape.layout()))
    759       << output_shape.DebugString();
    760 
    761   switch (output_shape.element_type()) {
    762     case F64:
    763     case F32:
    764     case F16:
    765       return IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape);
    766     default:
    767       return false;
    768   }
    769 }
    770 
    771 bool IsAlignedGemm(const DotInfo& dot_info,
    772                    const TargetMachineFeatures& target_machine_features) {
    773   if (ShapeUtil::IsZeroElementArray(dot_info.lhs_shape) ||
    774       ShapeUtil::IsZeroElementArray(dot_info.rhs_shape)) {
    775     return false;
    776   }
    777 
    778   return AreGemmShapes(dot_info.lhs_shape, dot_info.rhs_shape,
    779                        dot_info.result_shape, target_machine_features);
    780 }
    781 
    782 bool CanEmitTiledLlvmIrGemm(
    783     const HloModuleConfig& config, const DotInfo& dot_info,
    784     const TargetMachineFeatures& target_machine_features) {
    785   CHECK(IsAlignedGemm(dot_info, target_machine_features));
    786 
    787   if (ShouldUseMultiThreadedEigen(config)) {
    788     return false;
    789   }
    790 
    791   int m = dot_info.result_shape.dimensions(0);
    792   int k = dot_info.lhs_shape.dimensions(
    793       dot_info.dim_nums.lhs_contracting_dimensions(0));
    794   int n = dot_info.result_shape.dimensions(1);
    795 
    796   if (!options::ForceEnableExperimentalLlvmIrGemm(config)) {
    797     // TODO(sanjoy):  We should make these numbers micro-arch specific.
    798     bool small_gemm =
    799         k <= 128 && ((m <= 32 && n <= 128) || (m <= 128 && n <= 32));
    800     if (!small_gemm) {
    801       return false;
    802     }
    803   }
    804 
    805   bool lhs_non_canonical = dot_info.dim_nums.lhs_contracting_dimensions(0) == 0;
    806   bool rhs_non_canonical = dot_info.dim_nums.rhs_contracting_dimensions(0) == 1;
    807 
    808   if (lhs_non_canonical || rhs_non_canonical) {
    809     return false;
    810   }
    811 
    812   if (dot_info.result_shape.element_type() == F16) {
    813     // TODO(sanjoy): This is probably easy to fix, but I want to keep the CL
    814     // adding this comment NFC.
    815     return false;
    816   }
    817 
    818   return true;
    819 }
    820 
    821 DotImplementationStrategy GetDotImplementationStrategy(
    822     const HloModuleConfig& config, const DotInfo& dot_info,
    823     const TargetMachineFeatures& target_machine_features) {
    824   PrimitiveType element_type = dot_info.result_shape.element_type();
    825   // Any Matrix-Vector product of floating point or integral type, or
    826   // a transpose-dot fusion of the same can be lowered to a tiled LLVM
    827   // IR implementation.
    828   if (dot_info.result_shape.dimensions_size() == 2 &&
    829       (dot_info.result_shape.dimensions(0) == 1 ||
    830        dot_info.result_shape.dimensions(1) == 1) &&
    831       (primitive_util::IsFloatingPointType(element_type) ||
    832        primitive_util::IsIntegralType(element_type))) {
    833     return DotImplementationStrategy::kTiledLlvmIrGemv;
    834   }
    835 
    836   if (IsAlignedGemm(dot_info, target_machine_features)) {
    837     return CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)
    838                ? DotImplementationStrategy::kTiledLlvmIrGemm
    839                : DotImplementationStrategy::kEigen;
    840   }
    841 
    842   return DotImplementationStrategy::kNaiveLlvmIr;
    843 }
    844 
    845 Status EmitNonBatchDotOperation(
    846     DotInfo dot_info, string hlo_name, const llvm_ir::IrArray& target_array,
    847     const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
    848     const llvm_ir::IrArray* addend_array,
    849     llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
    850     const HloModuleConfig& hlo_module_config,
    851     const TargetMachineFeatures& target_machine_features) {
    852   PrimitiveType type = target_array.GetShape().element_type();
    853   TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type ||
    854                C128 == type);
    855   DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name),
    856                            target_array, lhs_array, rhs_array, addend_array,
    857                            executable_run_options_value, b, hlo_module_config,
    858                            target_machine_features);
    859   return dot_emitter.Emit();
    860 }
    861 
    862 Shape DropFirstDim(const Shape& shape) {
    863   absl::Span<int64 const> array_shape_dims(shape.dimensions());
    864   array_shape_dims.remove_prefix(1);
    865   return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
    866                                                   array_shape_dims);
    867 }
    868 
    869 Shape CollapseFirstNDims(const Shape& shape, int64 n) {
    870   absl::Span<int64 const> input_shape_dims(shape.dimensions());
    871   int64 prefix_dim =
    872       std::accumulate(input_shape_dims.begin(), input_shape_dims.begin() + n,
    873                       1ll, std::multiplies<int64>());
    874   DimensionVector result_dims;
    875   result_dims.push_back(prefix_dim);
    876   std::copy(input_shape_dims.begin() + n, input_shape_dims.end(),
    877             std::back_inserter(result_dims));
    878   return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
    879                                                   result_dims);
    880 }
    881 
    882 llvm_ir::IrArray CollapseFirstNDims(llvm::IRBuilder<>* b,
    883                                     const llvm_ir::IrArray& array, int64 n) {
    884   llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
    885   const Shape& shape = array.GetShape();
    886   CHECK(shape.has_layout() &&
    887         LayoutUtil::IsMonotonicWithDim0Major(shape.layout()));
    888   CHECK_GE(shape.dimensions_size(), n);
    889   Shape new_shape = CollapseFirstNDims(shape, n);
    890   llvm::Value* new_value = b->CreateBitCast(
    891       array.GetBasePointer(),
    892       llvm_ir::ShapeToIrType(new_shape, module)->getPointerTo());
    893   return llvm_ir::IrArray(new_value, std::move(new_shape));
    894 }
    895 
    896 Status ValidateDotDimensionNumbers(const DotDimensionNumbers& dim_numbers) {
    897   // Checks some invariants that do not hold in general, but DotDecomposer
    898   // should have established for us.  This is just a debugging aid.
    899   TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1);
    900   std::vector<int64> batch_dim_numbers(dim_numbers.lhs_batch_dimensions_size());
    901   absl::c_iota(batch_dim_numbers, 0);
    902   TF_RET_CHECK(
    903       absl::c_equal(batch_dim_numbers, dim_numbers.lhs_batch_dimensions()));
    904   TF_RET_CHECK(
    905       absl::c_equal(batch_dim_numbers, dim_numbers.rhs_batch_dimensions()));
    906   return Status::OK();
    907 }
    908 
    909 // Slice out the inner array at batch index `batch_index` from `outer_array`.
    910 llvm_ir::IrArray SliceOutInnerArray(llvm_ir::IrArray outer_array,
    911                                     llvm::Value* batch_index,
    912                                     llvm::IRBuilder<>* b) {
    913   llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
    914 
    915   Shape inner_shape = DropFirstDim(outer_array.GetShape());
    916   std::vector<llvm::Value*> multidim_index(inner_shape.rank() + 1,
    917                                            b->getInt64(0));
    918   multidim_index[0] = batch_index;
    919   llvm_ir::IrArray::Index slice_index(multidim_index, outer_array.GetShape(),
    920                                       batch_index->getType());
    921   llvm::Value* slice_ptr = outer_array.EmitArrayElementAddress(slice_index, b);
    922   llvm::Type* slice_ptr_type =
    923       llvm_ir::ShapeToIrType(inner_shape, module)->getPointerTo();
    924   return llvm_ir::IrArray(b->CreateBitCast(slice_ptr, slice_ptr_type),
    925                           std::move(inner_shape));
    926 }
    927 
    928 Status EmitBatchDotOperation(
    929     const HloInstruction& dot, const llvm_ir::IrArray& target_array,
    930     const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
    931     llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
    932     const HloModuleConfig& hlo_module_config,
    933     const TargetMachineFeatures& target_machine_features) {
    934   TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot.dot_dimension_numbers()));
    935 
    936   // Lower a batch dot into a sequence of non-batch dot operations.
    937 
    938   int64 num_batch_dims =
    939       dot.dot_dimension_numbers().lhs_batch_dimensions_size();
    940 
    941   // First reshape the inputs to make sure we only have one batch dimension.
    942   // This is a no-op bitcast because the operands have to be in row-major layout
    943   // (enforced in CpuLayoutAssignment), and the batch dimensions are the leading
    944   // dimensions (established by DotDecomposer and checked by
    945   // ValidateDotDimensionNumbers above).
    946   llvm_ir::IrArray lhs_array_reshaped =
    947       CollapseFirstNDims(b, lhs_array, num_batch_dims);
    948   llvm_ir::IrArray rhs_array_reshaped =
    949       CollapseFirstNDims(b, rhs_array, num_batch_dims);
    950   llvm_ir::IrArray target_array_reshaped =
    951       CollapseFirstNDims(b, target_array, num_batch_dims);
    952 
    953   int64 batch_count = lhs_array_reshaped.GetShape().dimensions(0);
    954 
    955   KernelSupportLibrary ksl(b);
    956 
    957   return ksl.ForWithStatus(
    958       llvm_ir::IrName(&dot, "bdot"), /*start=*/0, /*end=*/batch_count,
    959       /*step=*/1, [&](llvm::Value* indvar) {
    960         DotDimensionNumbers adjusted_dim_numbers = dot.dot_dimension_numbers();
    961         adjusted_dim_numbers.clear_lhs_batch_dimensions();
    962         adjusted_dim_numbers.clear_rhs_batch_dimensions();
    963 
    964         // Create a DotInfo representing the "inner" non-batch dot operation.
    965         DotInfo dot_info;
    966         dot_info.lhs_shape = DropFirstDim(lhs_array_reshaped.GetShape());
    967         dot_info.rhs_shape = DropFirstDim(rhs_array_reshaped.GetShape());
    968         dot_info.result_shape = DropFirstDim(target_array_reshaped.GetShape());
    969         dot_info.dim_nums = dot.dot_dimension_numbers();
    970         dot_info.dim_nums.clear_lhs_batch_dimensions();
    971         dot_info.dim_nums.clear_rhs_batch_dimensions();
    972 
    973         dot_info.dim_nums.set_lhs_contracting_dimensions(
    974             0,
    975             dot_info.dim_nums.lhs_contracting_dimensions(0) - num_batch_dims);
    976         dot_info.dim_nums.set_rhs_contracting_dimensions(
    977             0,
    978             dot_info.dim_nums.rhs_contracting_dimensions(0) - num_batch_dims);
    979 
    980         llvm_ir::IrArray lhs_slice =
    981             SliceOutInnerArray(lhs_array_reshaped, /*batch_index=*/indvar, b);
    982         llvm_ir::IrArray rhs_slice =
    983             SliceOutInnerArray(rhs_array_reshaped, /*batch_index=*/indvar, b);
    984         llvm_ir::IrArray target_slice = SliceOutInnerArray(
    985             target_array_reshaped, /*batch_index=*/indvar, b);
    986 
    987         // Emit the inner non-batch dot operation.
    988         return EmitNonBatchDotOperation(
    989             dot_info, dot.name(), target_slice, lhs_slice, rhs_slice, nullptr,
    990             executable_run_options_value, b, hlo_module_config,
    991             target_machine_features);
    992       });
    993 }
    994 
    995 bool IsBatchDot(const HloInstruction& instr) {
    996   if (auto* dot_instr = DynCast<HloDotInstruction>(&instr)) {
    997     return dot_instr->dot_dimension_numbers().lhs_batch_dimensions_size() > 0;
    998   }
    999 
   1000   return false;
   1001 }
   1002 }  // namespace
   1003 
   1004 bool DotImplementationCanHandleTranspose(
   1005     const HloInstruction& dot_instr,
   1006     const TargetMachineFeatures& target_machine_features) {
   1007   DotImplementationStrategy impl_strategy =
   1008       GetDotImplementationStrategy(dot_instr.parent()->parent()->config(),
   1009                                    DotInfo(dot_instr), target_machine_features);
   1010 
   1011   // TODO(sanjoy): This is not quite right, it should be `impl_strategy ==
   1012   // kEigen || impl_strategy == kTiledLlvmIrGemv || impl_strategy ==
   1013   // kNaiveLlvmIr` but I'll fix this in a later CL in the interest of keeping
   1014   // the CL adding this comment NFC.
   1015   return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm ||
   1016          impl_strategy == DotImplementationStrategy::kEigen;
   1017 }
   1018 
   1019 bool DotOperandsAndResultMustHaveRowMajorLayout(
   1020     const HloInstruction& dot_instr,
   1021     const TargetMachineFeatures& target_machine_features) {
   1022   DotImplementationStrategy impl_strategy =
   1023       GetDotImplementationStrategy(dot_instr.parent()->parent()->config(),
   1024                                    DotInfo(dot_instr), target_machine_features);
   1025 
   1026   return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm ||
   1027          impl_strategy == DotImplementationStrategy::kEigen;
   1028 }
   1029 
   1030 Status EmitDotOperation(const HloInstruction& dot,
   1031                         const llvm_ir::IrArray& target_array,
   1032                         const llvm_ir::IrArray& lhs_array,
   1033                         const llvm_ir::IrArray& rhs_array,
   1034                         const llvm_ir::IrArray* addend_array,
   1035                         llvm::Value* executable_run_options_value,
   1036                         llvm::IRBuilder<>* b,
   1037                         const HloModuleConfig& hlo_module_config,
   1038                         const TargetMachineFeatures& target_machine_features) {
   1039   // This routine assumes that the dot operation is not in a parallelized
   1040   // enclosing computation.
   1041   CHECK(dot.parent()->root_instruction()->outer_dimension_partitions().empty());
   1042 
   1043   if (IsBatchDot(dot)) {
   1044     TF_RET_CHECK(addend_array == nullptr);
   1045     return EmitBatchDotOperation(dot, target_array, lhs_array, rhs_array,
   1046                                  executable_run_options_value, b,
   1047                                  hlo_module_config, target_machine_features);
   1048   }
   1049 
   1050   return EmitNonBatchDotOperation(DotInfo(dot), dot.name(), target_array,
   1051                                   lhs_array, rhs_array, addend_array,
   1052                                   executable_run_options_value, b,
   1053                                   hlo_module_config, target_machine_features);
   1054 }
   1055 }  // namespace cpu
   1056 }  // namespace xla
   1057