/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"

#include <string>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/algorithm/container.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;

namespace gpu {

IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config,
                     IrEmitterContext* ir_emitter_context, bool is_nested)
    : ir_emitter_context_(ir_emitter_context),
      module_(ir_emitter_context->llvm_module()),
      b_(module_->getContext()),
      bindings_(ir_emitter_context->hlo_module(),
                &ir_emitter_context->buffer_assignment(), &b_, module_,
                is_nested),
      hlo_module_config_(hlo_module_config) {
}

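// Lowers an elementwise HLO by binding, for each operand, a generator that
// reads that operand's element at the loop index, and then emitting a single
// loop nest over the output elements. Roughly (an illustrative sketch, not the
// exact emitted IR):
//
//   for (index : output_shape)
//     output[index] = elemental_op(operand0[index], ..., operandN[index]);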
Status IrEmitter::DefaultAction(HloInstruction* hlo) {
  ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
      return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_);
    };
  }
  return EmitTargetElementLoop(
      *hlo, GpuElementalIrEmitter(hlo_module_config_, module_, &b_,
                                  GetNestedComputer())
                .MakeElementGenerator(hlo, operand_to_generator));
}

Status IrEmitter::HandleConstant(HloInstruction* constant) {
  return Status::OK();
}

Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
  VLOG(2) << "HandleBitcast: " << bitcast->ToString();
  const HloInstruction* operand = bitcast->operand(0);
  // Bitcast is a no-op, but we still want to bind it to an llvm::Value
  // sometimes, e.g., when its operand is a constant or a bitcast of a
  // constant.
  if (bindings_.BoundToIrValue(*operand)) {
    bindings_.BindHloToIrValue(*bitcast, GetBasePointer(*operand));
  }
  return Status::OK();
}

Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
  VLOG(2) << "HandleAddDependency: " << add_dependency->ToString();
  const HloInstruction* operand = add_dependency->operand(0);
  // AddDependency is a no-op, but we still want to bind it to an llvm::Value
  // sometimes, e.g., when its operand is a constant or a bitcast of a
  // constant.
  if (bindings_.BoundToIrValue(*operand)) {
    bindings_.BindHloToIrValue(*add_dependency, GetBasePointer(*operand));
  }
  return Status::OK();
}

Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
  auto operand = get_tuple_element->operand(0);
  CHECK(bindings_.BoundToIrValue(*operand));
  bindings_.BindHloToIrValue(
      *get_tuple_element,
      llvm_ir::EmitGetTupleElement(
          get_tuple_element->shape(), get_tuple_element->tuple_index(),
          // TODO(b/26344050): tighten the alignment here
          // based on the real element type.
          /*alignment=*/1, GetBasePointer(*operand), &b_));
  return Status::OK();
}

Status IrEmitter::HandleSend(HloInstruction*) {
  return Unimplemented("Send is not implemented on GPU");
}

Status IrEmitter::HandleSendDone(HloInstruction*) {
  return Unimplemented("Send-Done is not implemented on GPU");
}

Status IrEmitter::HandleRecv(HloInstruction*) {
  return Unimplemented("Recv is not implemented on GPU");
}

Status IrEmitter::HandleRecvDone(HloInstruction*) {
  return Unimplemented("Recv-done is not implemented on GPU");
}

Status IrEmitter::HandleScatter(HloInstruction*) {
  return Unimplemented("Scatter is not implemented on GPUs.");
}

Status IrEmitter::HandleTuple(HloInstruction* tuple) {
  std::vector<llvm::Value*> base_ptrs;
  for (const HloInstruction* operand : tuple->operands()) {
    base_ptrs.push_back(GetBasePointer(*operand));
  }
  llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &b_);
  return Status::OK();
}

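// Emits (and memoizes in computation_to_ir_function_) an LLVM function for
// nested_computation and calls it. As set up below, the call passes a pointer
// for each operand, then the output pointer, then the temp buffer base.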
Status IrEmitter::EmitCallToNestedComputation(
    const HloComputation& nested_computation,
    absl::Span<llvm::Value* const> operands, llvm::Value* output) {
  TF_RET_CHECK(nested_computation.num_parameters() > 0);
  llvm::Function*& emitted_function =
      computation_to_ir_function_[&nested_computation];
  if (emitted_function == nullptr) {
    IrEmitterNested ir_emitter_nested(hlo_module_config_, nested_computation,
                                      ir_emitter_context_);
    TF_RETURN_IF_ERROR(
        nested_computation.root_instruction()->Accept(&ir_emitter_nested));
    emitted_function = ir_emitter_nested.GetEmittedFunction();
  }

  std::vector<llvm::Value*> arguments(operands.begin(), operands.end());
  arguments.push_back(output);
  arguments.push_back(bindings_.GetTempBufferBase());
  Call(emitted_function, arguments);

  return Status::OK();
}

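// Tries to lower the two-parameter combiner directly to a single hardware
// atomic (atomic store, atomicAdd, or atomicMax/atomicMin). Returns false when
// no direct mapping applies, so the caller falls back to the atomicCAS loop
// emitted by EmitAtomicOperationUsingCAS.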
bool IrEmitter::MaybeEmitDirectAtomicOperation(
    const HloComputation& computation, llvm::Value* output_address,
    llvm::Value* source_address) {
  CHECK_EQ(2, computation.num_parameters());

  if (computation.instruction_count() != 3) {
    // We special-case only computations with one computing instruction for
    // now. Such a computation has exactly three instructions, given that it
    // has two parameters.
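    // For example, a combiner like { p0 = parameter(0); p1 = parameter(1);
    // ROOT add = add(p0, p1) } consists of exactly the two parameters plus one
    // computing root instruction.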
    return false;
  }

  HloOpcode root_opcode = computation.root_instruction()->opcode();
  PrimitiveType element_type =
      computation.root_instruction()->shape().element_type();
  bool is_atomic_integral = element_type == S32 || element_type == U32 ||
                            element_type == S64 || element_type == U64;
  llvm::Value* source = Load(source_address, "source");

  // kCopy of RHS -> atomic store.
  if (root_opcode == HloOpcode::kCopy &&
      (element_type == F32 || is_atomic_integral) &&
      computation.root_instruction()->operand(0)->opcode() ==
          HloOpcode::kParameter &&
      computation.root_instruction()->operand(0)->parameter_number() == 1) {
    llvm::StoreInst* store = Store(source, output_address);
    store->setAtomic(llvm::AtomicOrdering::Unordered);
    // Derive a minimum alignment from the type. The optimizer can increase it
    // later.
    store->setAlignment(ShapeUtil::ByteSizeOfPrimitiveType(element_type));
    return true;
  }

  if (root_opcode == HloOpcode::kAdd) {
    // NVPTX supports atomicAdd on F32 and integer types.
    if (element_type == F32) {
      // F32 + F32
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_atomic_load_add_f32,
                                   {output_address, source},
                                   {output_address->getType()}, &b_);
      return true;
    }
    if (is_atomic_integral) {
      // integral + integral
      AtomicRMW(llvm::AtomicRMWInst::Add, output_address, source,
                llvm::AtomicOrdering::SequentiallyConsistent);
      return true;
    }
  }

  // NVPTX supports atomicMax and atomicMin only on integer types.
  if (root_opcode == HloOpcode::kMaximum && is_atomic_integral) {
    // max(integral, integral)
    auto opcode = primitive_util::IsSignedIntegralType(element_type)
                      ? llvm::AtomicRMWInst::Max
                      : llvm::AtomicRMWInst::UMax;
    AtomicRMW(opcode, output_address, source,
              llvm::AtomicOrdering::SequentiallyConsistent);
    return true;
  }

  if (root_opcode == HloOpcode::kMinimum && is_atomic_integral) {
    // min(integral, integral)
    auto opcode = primitive_util::IsSignedIntegralType(element_type)
                      ? llvm::AtomicRMWInst::Min
                      : llvm::AtomicRMWInst::UMin;
    AtomicRMW(opcode, output_address, source,
              llvm::AtomicOrdering::SequentiallyConsistent);
    return true;
  }

  return false;
}

// Implements atomic binary operations using atomic compare-and-swap
// (atomicCAS) as follows:
//   1. Reads the value from the memory pointed to by output_address and
//     records it as old_output.
//   2. Uses old_output as one of the source operands to perform the binary
//     operation and stores the result in new_output.
//   3. Calls atomicCAS, which implements compare-and-swap as an atomic
//     operation. In particular, atomicCAS reads the value from the memory
//     pointed to by output_address and compares it with old_output. If the
//     two values are equal, new_output is written to the same memory location
//     and true is returned to indicate that the atomic operation succeeds.
//     Otherwise, the new value read from the memory is returned. In this case,
//     the new value is copied to old_output, and steps 2. and 3. are repeated
//     until atomicCAS succeeds.
//
// On NVIDIA GPUs, atomicCAS can only operate on 32-bit and 64-bit integers. If
// the element type of the binary operation is 32 or 64 bits wide, the integer
// type of the same size is used for the atomicCAS operation. If the element
// type is narrower than 32 bits, int32 is used for the atomicCAS operation. In
// that case atomicCAS reads and writes 32 bits of memory, which is more than
// the original atomic binary operation requires, so we mask off the lowest two
// bits of output_address and use the result as the address from which to read
// the 32-bit value. This avoids out-of-bounds memory accesses as long as
// tensor buffers are 4-byte aligned and have a size that is a multiple of 4,
// which the runtime guarantees.
//
// The pseudo code is shown below. Variables *_address are pointers to a memory
// region with a size equal to the size of the atomicCAS operation, with the
// exception that new_output_address is a pointer to a memory region with a size
// equal to the element size of the binary operation.
//
//   element_size = sizeof(element_type);
//   atomic_size = max(32, element_size);
//   cas_new_output_address = alloca(atomic_size);
//   cas_old_output_address = alloca(atomic_size);
//   if (atomic_size != element_size) {
//     atomic_address = output_address & ((int64)(-4));
//     new_output_address = cas_new_output_address + (output_address & 3);
//   } else {
//     atomic_address = output_address;
//     new_output_address = cas_new_output_address;
//   }
//
//   *cas_old_output_address = *atomic_address;
//   do {
//     *cas_new_output_address = *cas_old_output_address;
//     *new_output_address = operation(*new_output_address, *source_address);
//     (*cas_old_output_address, success) =
//       atomicCAS(atomic_address, *cas_old_output_address,
//       *cas_new_output_address);
//   } while (!success);
//
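// As an illustrative example of the narrow-element path: for a 16-bit element
// at byte address 0x1002, atomic_address becomes 0x1000 and the offset is 2,
// so atomicCAS operates on the whole 32-bit word at 0x1000 while the nested
// computation only updates the two bytes at offset 2 within the scratch word.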
Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
                                              llvm::Value* output_address,
                                              llvm::Value* source_address) {
  llvm::PointerType* output_address_type =
      llvm::dyn_cast<llvm::PointerType>(output_address->getType());
  CHECK_NE(output_address_type, nullptr);

  // element_type is the data type for the binary operation.
  llvm::Type* element_type = output_address_type->getPointerElementType();
  int element_size = llvm_ir::GetSizeInBits(element_type);
  llvm::Type* element_address_type = element_type->getPointerTo();

  int atomic_size = (element_size < 32) ? 32 : element_size;
  llvm::Type* atomic_type = b_.getIntNTy(atomic_size);
  llvm::Type* atomic_address_type =
      atomic_type->getPointerTo(output_address_type->getPointerAddressSpace());

  // cas_old_output_address and cas_new_output_address point to the scratch
  // memory where we store the old and new values for the repeated atomicCAS
  // operations.
  llvm::Value* cas_old_output_address =
      Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address");
  llvm::Value* cas_new_output_address =
      Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address");

  // Emit preparation code to the preheader.
  llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock();

  llvm::Value* atomic_memory_address;
  // binop_output_address points to the scratch memory that stores the
  // result of the binary operation.
  llvm::Value* binop_output_address;
  if (element_size < 32) {
    // element_size is in bits; it must be an integer number of bytes.
    CHECK_EQ(element_size % 8, 0);
    llvm::Type* address_int_type =
        module_->getDataLayout().getIntPtrType(output_address_type);
    atomic_memory_address = PtrToInt(output_address, address_int_type);
    llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3);
    llvm::Value* offset = And(atomic_memory_address, mask);
    mask = llvm::ConstantInt::get(address_int_type, -4);
    atomic_memory_address = And(atomic_memory_address, mask);
    atomic_memory_address =
        IntToPtr(atomic_memory_address, atomic_address_type);
    binop_output_address =
        Add(PtrToInt(cas_new_output_address, address_int_type), offset);
    binop_output_address = IntToPtr(binop_output_address, element_address_type);
  } else {
    atomic_memory_address = BitCast(output_address, atomic_address_type);
    binop_output_address =
        BitCast(cas_new_output_address, element_address_type);
  }

  // Use the value from the memory that atomicCAS operates on to initialize
  // cas_old_output.
  llvm::Value* cas_old_output = Load(atomic_memory_address, "cas_old_output");
  Store(cas_old_output, cas_old_output_address);

  llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock(
      b_.GetInsertPoint(), "atomic_op_loop_exit");
  llvm::BasicBlock* loop_body_bb = llvm::BasicBlock::Create(
      b_.getContext(), "atomic_op_loop_body", b_.GetInsertBlock()->getParent());
  b_.SetInsertPoint(loop_body_bb);
  // Change preheader's successor from loop_exit_bb to loop_body_bb.
  loop_preheader_bb->getTerminator()->setSuccessor(0, loop_body_bb);

  // Emit the body of the loop that repeatedly invokes atomicCAS.
  //
  // Use cas_old_output to initialize cas_new_output.
  cas_old_output = Load(cas_old_output_address, "cas_old_output");
  Store(cas_old_output, cas_new_output_address);
  // Emits code to calculate new_output = operation(old_output, source);
  TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
      computation, {binop_output_address, source_address},
      binop_output_address));

  llvm::Value* cas_new_output = Load(cas_new_output_address, "cas_new_output");

  // Emit code to perform the atomicCAS operation
  // (cas_old_output, success) = atomicCAS(memory_address, cas_old_output,
  //                                       cas_new_output);
  llvm::Value* ret_value =
      AtomicCmpXchg(atomic_memory_address, cas_old_output, cas_new_output,
                    llvm::AtomicOrdering::SequentiallyConsistent,
                    llvm::AtomicOrdering::SequentiallyConsistent);

  // Extract the memory value returned from atomicCAS and store it as
  // cas_old_output.
  Store(ExtractValue(ret_value, 0, "cas_old_output"), cas_old_output_address);
  // Extract the success bit returned from atomicCAS and generate a
  // conditional branch on the success bit.
  CondBr(ExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb);

  // Set the insertion point to the exit basic block so that the caller of
  // this method can continue emitting code to the right place.
  SetToFirstInsertPoint(loop_exit_bb, &b_);
  return Status::OK();
}

Status IrEmitter::EmitAtomicOperationForNestedComputation(
    const HloComputation& computation, llvm::Value* output_address,
    llvm::Value* source_address) {
  if (computation.num_parameters() != 2) {
    // TODO(b/30258929): We only accept binary computations so far.
    return Unimplemented(
        "We only support atomic functions with exactly two parameters, but "
        "computation %s has %d.",
        computation.name(), computation.num_parameters());
  }

  if (MaybeEmitDirectAtomicOperation(computation, output_address,
                                     source_address)) {
    return Status::OK();
  }

  return EmitAtomicOperationUsingCAS(computation, output_address,
                                     source_address);
}

Status IrEmitter::HandleSelect(HloInstruction* select) {
  auto pred = select->operand(0);
  TF_RET_CHECK(pred->shape().element_type() == PRED);
  // We must not call the subclass `DefaultAction` method, lest its
  // `HandleSelect` call `IrEmitter::HandleSelect` and its `DefaultAction`
  // assume no handler has already been called.
  return IrEmitter::DefaultAction(select);
}

Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
  auto pred = tuple_select->operand(0);
  auto on_true = tuple_select->operand(1);
  auto on_false = tuple_select->operand(2);
  TF_RET_CHECK(pred->shape().element_type() == PRED);
  TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()));
  TF_RET_CHECK(tuple_select->shape().IsTuple());
  llvm_ir::EmitTupleSelect(GetIrArray(*tuple_select, *tuple_select),
                           GetIrArray(*pred, *tuple_select),
                           GetBasePointer(*on_true), GetBasePointer(*on_false),
                           &b_);
  return Status::OK();
}

namespace {
llvm::Value* Real(llvm::Value* x, llvm::IRBuilder<>* b) {
  return b->CreateExtractValue(x, {0});
}

llvm::Value* Imag(llvm::Value* x, llvm::IRBuilder<>* b) {
  return b->CreateExtractValue(x, {1});
}

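// Computes (a + bi) * (c + di) = (ac - bd) + (ad + bc)i on {real, imaginary}
// pairs using four multiplies, one subtract, and one add.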
std::pair<llvm::Value*, llvm::Value*> MultiplyComplex(llvm::Value* lhs_value,
                                                      llvm::Value* rhs_value,
                                                      llvm::IRBuilder<>* b) {
  llvm::Value* lhs_real = Real(lhs_value, b);
  llvm::Value* lhs_imag = Imag(lhs_value, b);
  llvm::Value* rhs_real = Real(rhs_value, b);
  llvm::Value* rhs_imag = Imag(rhs_value, b);
  llvm::Value* real_result1 = b->CreateFMul(lhs_real, rhs_real);
  llvm::Value* real_result2 = b->CreateFMul(lhs_imag, rhs_imag);
  llvm::Value* real_result = b->CreateFSub(real_result1, real_result2);
  llvm::Value* imag_result1 = b->CreateFMul(lhs_real, rhs_imag);
  llvm::Value* imag_result2 = b->CreateFMul(lhs_imag, rhs_real);
  llvm::Value* imag_result = b->CreateFAdd(imag_result1, imag_result2);
  return {real_result, imag_result};
}
}  // namespace

Status IrEmitter::HandleDot(HloInstruction* dot) {
  auto lhs_instruction = dot->operand(0);
  auto rhs_instruction = dot->operand(1);
  const llvm_ir::IrArray& target_array = GetIrArray(*dot, *dot);
  const llvm_ir::IrArray& lhs_array = GetIrArray(*lhs_instruction, *dot);
  const llvm_ir::IrArray& rhs_array = GetIrArray(*rhs_instruction, *dot);

  const Shape& lhs_shape = lhs_instruction->shape();
  const Shape& rhs_shape = rhs_instruction->shape();
  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
  CHECK_EQ(dnums.lhs_batch_dimensions_size(),
           dnums.rhs_batch_dimensions_size());

  // TODO(b/110211620): Convert to use i32 index_type when it is possible.
  llvm::Type* index_type = b_.getInt64Ty();
  llvm_ir::IrArray::Index element_index(index_type);
  if (ShapeUtil::IsScalar(lhs_shape) && ShapeUtil::IsScalar(rhs_shape)) {
    // If the operands are scalar, don't emit any loops.
    llvm::Value* lhs_value =
        lhs_array.EmitReadArrayElement(/*index=*/element_index, &b_);
    llvm::Value* rhs_value =
        rhs_array.EmitReadArrayElement(/*index=*/element_index, &b_);
    llvm::Value* result;
    if (ShapeUtil::ElementIsComplex(lhs_shape)) {
      auto value = MultiplyComplex(lhs_value, rhs_value, &b_);
      result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType());
      result = InsertValue(result, value.first, {0});
      result = InsertValue(result, value.second, {1});
    } else if (ShapeUtil::ElementIsFloating(lhs_shape)) {
      result = FMul(lhs_value, rhs_value);
    } else {
      TF_RET_CHECK(ShapeUtil::ElementIsIntegral(lhs_shape));
      result = Mul(lhs_value, rhs_value);
    }
    target_array.EmitWriteArrayElement(/*index=*/element_index, result, &b_);
    return Status::OK();
  }

  // "Scalar dot non-scalar" or "non-scalar dot scalar" is invalid. See
  // the semantics of Dot in the XLA documentation for details.
  TF_RET_CHECK(!ShapeUtil::IsScalar(lhs_shape) &&
               !ShapeUtil::IsScalar(rhs_shape));

  const int64 lhs_reduction_dimension = dnums.lhs_contracting_dimensions(0);
  const int64 rhs_reduction_dimension = dnums.rhs_contracting_dimensions(0);

  // Check that the batch dims don't cover the reduction dimensions.
  for (int64 batch_dim : dnums.lhs_batch_dimensions()) {
    CHECK_NE(lhs_reduction_dimension, batch_dim);
    CHECK_NE(rhs_reduction_dimension, batch_dim);
  }

  // Verify that the reduction dimensions in both operands have the same size.
  TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) ==
               rhs_shape.dimensions(rhs_reduction_dimension))
      << "lhs_shape.dimensions(" << lhs_reduction_dimension
      << ") = " << lhs_shape.dimensions(lhs_reduction_dimension)
      << ", and rhs_shape.dimensions(" << rhs_reduction_dimension
      << ") = " << rhs_shape.dimensions(rhs_reduction_dimension);

  // Create loop nests which loop through the LHS operand dimensions and the RHS
  // operand dimensions. The reduction dimensions of the LHS and RHS are handled
  // in a separate innermost loop which performs the sum of products.
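  //
  // An illustrative sketch (assuming a plain 2-D dot with no batch dims) of
  // the loop structure emitted below:
  //
  //   for (i : lhs non-contracted dims)
  //     for (j : rhs non-contracted dims) {
  //       accum = 0;
  //       for (k : contracted dim) accum += lhs[i, k] * rhs[k, j];
  //       target[i, j] = accum;
  //     }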
  llvm_ir::ForLoopNest loop_nest(IrName(dot), &b_);
  std::vector<llvm::Value*> lhs_multi_index =
      loop_nest.EmitOperandArrayLoopNest(
          lhs_array, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs");
  std::vector<llvm::Value*> rhs_multi_index =
      loop_nest.EmitOperandArrayLoopNest(
          rhs_array, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");

  // We don't have to iterate over the batch dimensions in both arrays, so we
  // reuse the LHS batch indices and simplify the loop nest of the rhs.
  for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) {
    DCHECK(absl::c_linear_search(dnums.lhs_batch_dimensions(), i));
    rhs_multi_index[i] = lhs_multi_index[i];
  }

  // Create the reduction loop which does the sum of products reduction.
  std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
      /*start_index=*/0,
      /*end_index=*/lhs_shape.dimensions(lhs_reduction_dimension),
      /*suffix=*/"reduction");

  // The entries for the reduction dimension in the lhs and rhs indexes are the
  // indvar of the reduction loop.
  lhs_multi_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue();
  rhs_multi_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue();

  // For computing the sum of products we alloca a single location to store the
  // dot product result as we accumulate it within the reduction loop. After the
  // reduction loop we load the result and store into the output array.
  llvm::Type* accum_type = target_array.GetElementLlvmType();
  llvm::Value* accum_address = llvm_ir::EmitAllocaAtFunctionEntry(
      accum_type,       // The pointee type of the alloca instruction.
      "accum_address",  // The name of the alloca instruction.
      &b_);

  // Initialize the accumulator in the preheader to zero.
  new llvm::StoreInst(
      llvm::Constant::getNullValue(lhs_array.GetElementLlvmType()),  // init 0
      accum_address,  // The address.
      reduction_loop->GetPreheaderBasicBlock()
          ->getTerminator());  // The instruction this store is inserted before.

  // Emit the body of the reduction loop:
  //   accum = *accum_address
  //   updated_accum = accum + lhs_element * rhs_element
  //   *accum_address = updated_accum
  TF_RET_CHECK(!reduction_loop->GetBodyBasicBlock()->empty());
  b_.SetInsertPoint(
      &*reduction_loop->GetBodyBasicBlock()->getFirstInsertionPt());
  llvm_ir::IrArray::Index lhs_index(lhs_multi_index, lhs_array.GetShape(),
                                    b_.getInt64Ty());
  llvm::Value* lhs_element = lhs_array.EmitReadArrayElement(lhs_index, &b_);
  llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_array.GetShape(),
                                    b_.getInt64Ty());
  llvm::Value* rhs_element = rhs_array.EmitReadArrayElement(rhs_index, &b_);
  llvm::Value* accum = Load(accum_address);
  llvm::Value* updated_accum;
  if (ShapeUtil::ElementIsComplex(lhs_shape)) {
    auto value = MultiplyComplex(lhs_element, rhs_element, &b_);
    llvm::Value* accum_real = Real(accum, &b_);
    llvm::Value* real_sum = FAdd(accum_real, value.first);
    updated_accum = InsertValue(accum, real_sum, {0});
    llvm::Value* accum_imag = Imag(accum, &b_);
    llvm::Value* imag_sum = FAdd(accum_imag, value.second);
    updated_accum = InsertValue(updated_accum, imag_sum, {1});
  } else if (ShapeUtil::ElementIsFloating(lhs_shape)) {
    llvm::Value* product = FMul(lhs_element, rhs_element);
    updated_accum = FAdd(accum, product);
  } else {
    TF_RET_CHECK(ShapeUtil::ElementIsIntegral(lhs_shape));
    llvm::Value* product = Mul(lhs_element, rhs_element);
    updated_accum = Add(accum, product);
  }
  Store(updated_accum, accum_address);

  // After the reduction loop exits, store the accumulator into the target
  // address. The index into the target array is the concatenation of the lhs
  // and rhs indexes with the reduction dimensions removed. The terms from the
  // lhs index are the lower dimensions in the index, so we add them first.
  std::vector<llvm::Value*> target_multi_index;
  for (size_t dimension = 0; dimension < lhs_index.size(); ++dimension) {
    if (dimension != lhs_reduction_dimension) {
      target_multi_index.push_back(lhs_index[dimension]);
    }
  }
  // Skip over the batch dimensions to not have them in the index twice.
  for (size_t dimension = dnums.lhs_batch_dimensions_size();
       dimension < rhs_index.size(); ++dimension) {
    if (dimension != rhs_reduction_dimension) {
      target_multi_index.push_back(rhs_index[dimension]);
    }
  }
  SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &b_);
  llvm_ir::IrArray::Index target_index(target_multi_index,
                                       target_array.GetShape(), index_type);
  target_array.EmitWriteArrayElement(
      target_index,
      Load(accum_address),  // The value written to the target array.
      &b_);

  // Set the IR builder insert point to the exit basic block of the outermost
  // loop. This ensures later instructions are inserted after this loop nest.
  b_.SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());

  return Status::OK();
}

Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
  if (ShapeUtil::IsZeroElementArray(convolution->shape())) {
    // Emit no code for an empty output.
    return Status::OK();
  }
  // TODO(b/31409998): Support convolution with dilation.
  return Unimplemented(
      "Hit a case for convolution that is not implemented on GPU.");
}

Status IrEmitter::HandleFft(HloInstruction* fft) {
  if (ShapeUtil::IsZeroElementArray(fft->shape())) {
    // Emit no code for an empty output.
    return Status::OK();
  }
  return Unimplemented("Hit a case for fft that is not implemented on GPU.");
}

Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
  // TODO(b/33011107): Support cross replica sum on GPU.
  return Unimplemented("AllReduce is not implemented on GPU.");
}

Status IrEmitter::HandleParameter(HloInstruction* parameter) {
  return Status::OK();
}

Status IrEmitter::HandleReduce(HloInstruction* reduce) {
  // TODO(b/118332391): Support variadic reduce.
  if (!reduce->shape().IsArray()) {
    return Unimplemented("Variadic reduce is not supported on GPU");
  }
  auto arg = reduce->operand(0);
  auto init_value = reduce->operand(1);
  absl::Span<const int64> dimensions(reduce->dimensions());
  HloComputation* function = reduce->to_apply();
  return EmitTargetElementLoop(
      *reduce,
      [=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
        // Initialize an accumulator with init_value.
        llvm::AllocaInst* accumulator_addr =
            Alloca(llvm_ir::PrimitiveTypeToIrType(
                reduce->shape().element_type(), module_));
        Store(Load(GetBasePointer(*init_value)), accumulator_addr);

        // The enclosing loops go over all the target elements. Now we have to
        // compute the actual target element. For this, we build a new loop nest
        // to iterate over all the reduction dimensions in the argument.
        // AddLoopsForShapeOnDimensions will return an Index where induction
        // Value*s are placed for each dimension in dimensions, and all the rest
        // are nullptrs.
        llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_);
        std::vector<llvm::Value*> input_multi_index =
            loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
                                               "reduction_dim");

        SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);

        // Build a full index for the input argument, using input_multi_index
        // as the base. In input_multi_index only the reduction dimensions are
        // filled in. We fill in the rest of the dimensions with induction
        // Value*s taken from 'index' which iterates over the target array.
        // See the high-level description in the XLA documentation for details.
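        // For example (illustrative only): for a [2,3,4] argument reduced over
        // dimension {1}, input_multi_index starts as {nullptr, r1, nullptr},
        // where r1 is the inner reduction loop's indvar; the nullptr slots are
        // then filled with the target indices i0 and i2 taken from 'index',
        // yielding the full input index {i0, r1, i2}.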
        llvm_ir::IrArray::Index::const_iterator it = index.begin();

        for (auto& i : input_multi_index) {
          if (i == nullptr) {
            i = *it++;
          }
        }
        CHECK(index.end() == it);

        // Apply the reduction function to the loaded value.
        llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
                                            b_.getInt64Ty());
        llvm::Value* input_address =
            GetIrArray(*arg, *reduce).EmitArrayElementAddress(input_index, &b_);
        TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
            *function, {accumulator_addr, input_address}, accumulator_addr));

        SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
        return Load(accumulator_addr);
      });
}

Status IrEmitter::HandleFusion(HloInstruction* fusion) {
  // kFusion for library calls should be handled by
  // IrEmitterUnnested::HandleFusion.
  CHECK_EQ(HloInstruction::FusionKind::kLoop, fusion->fusion_kind());
  GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
                                          GetNestedComputer());
  FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion),
                               &elemental_emitter);
  TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));

  return EmitTargetElementLoop(*fusion, fused_emitter.GetRootGenerator());
}

Status IrEmitter::HandleCall(HloInstruction* call) {
  std::vector<llvm::Value*> operand_addresses;
  for (HloInstruction* operand : call->operands()) {
    operand_addresses.push_back(GetBasePointer(*operand));
  }
  return EmitCallToNestedComputation(*call->to_apply(), operand_addresses,
                                     GetBasePointer(*call));
}

Status IrEmitter::HandleCustomCall(HloInstruction*) {
  return Unimplemented("custom-call");
}

Status IrEmitter::HandleInfeed(HloInstruction*) {
  // TODO(b/30467474): Implement infeed on GPU.
  return Unimplemented("Infeed is not supported on GPU.");
}

Status IrEmitter::HandleOutfeed(HloInstruction*) {
  // TODO(b/34359662): Implement outfeed on GPU.
  return Unimplemented("Outfeed is not supported on GPU.");
}

Status IrEmitter::HandleBatchNormInference(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormInference directly.  It "
      "should be lowered before IR emission to HLO-soup using "
      "BatchNormRewriter or to a cudnn CustomCall using "
      "CudnnBatchNormRewriter.");
}

Status IrEmitter::HandleBatchNormTraining(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormTraining directly.  It "
      "should be lowered before IR emission to HLO-soup using "
      "BatchNormRewriter or to a cudnn CustomCall using "
      "CudnnBatchNormRewriter.");
}

Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormGrad directly.  It should "
      "be lowered before IR emission to HLO-soup (using BatchNormRewriter) or "
      "to a cudnn CustomCall using CudnnBatchNormRewriter.");
}

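// Evaluates `computation` on already-loaded scalar element values by spilling
// each element into a function-entry alloca, calling the emitted nested
// function on those addresses, and loading the scalar result back. Roughly (an
// illustrative sketch, not the exact emitted IR):
//
//   %ret = alloca <result element type>
//   %p_i = alloca <param type>   ; one per element; store element_i, %p_i
//   call @nested_fn(%p_0, ..., %p_N, %ret, %temp_buffer_base)
//   %result = load %ret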
StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
    const HloComputation& computation,
    absl::Span<llvm::Value* const> parameter_elements) {
  llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
      llvm_ir::PrimitiveTypeToIrType(
          computation.root_instruction()->shape().element_type(), module_),
      "return_buffer", &b_);
  std::vector<llvm::Value*> parameter_buffers;
  for (llvm::Value* parameter_element : parameter_elements) {
    parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
        parameter_element->getType(), "parameter_buffer", &b_));
    Store(parameter_element, parameter_buffers.back());
  }
  TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers,
                                                 return_buffer));
  return Load(return_buffer);
}

std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(
    const HloInstruction& hlo) {
  std::vector<llvm_ir::IrArray> output_arrays;
  if (hlo.shape().IsTuple()) {
    int64 num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
    output_arrays.reserve(num_outputs);
    for (int64 i = 0; i < num_outputs; ++i) {
      output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
    }
  } else {
    output_arrays.push_back(GetIrArray(hlo, hlo));
  }
  return output_arrays;
}

}  // namespace gpu
}  // namespace xla