/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"

#include <string>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/algorithm/container.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;

namespace gpu {

IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config,
                     IrEmitterContext* ir_emitter_context, bool is_nested)
    : ir_emitter_context_(ir_emitter_context),
      module_(ir_emitter_context->llvm_module()),
      b_(module_->getContext()),
      bindings_(ir_emitter_context->hlo_module(),
                &ir_emitter_context->buffer_assignment(), &b_, module_,
                is_nested),
      hlo_module_config_(hlo_module_config) {}

Status IrEmitter::DefaultAction(HloInstruction* hlo) {
  ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
      return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_);
    };
  }
  return EmitTargetElementLoop(
      *hlo, GpuElementalIrEmitter(hlo_module_config_, module_, &b_,
                                  GetNestedComputer())
                .MakeElementGenerator(hlo, operand_to_generator));
}

Status IrEmitter::HandleConstant(HloInstruction* constant) {
  return Status::OK();
}

Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
  VLOG(2) << "HandleBitcast: " << bitcast->ToString();
  const HloInstruction* operand = bitcast->operand(0);
  // Bitcast is a no-op, but we still want to bind it to an llvm::Value
  // sometimes, e.g., when its operand is a constant or a bitcast of a
  // constant.
  if (bindings_.BoundToIrValue(*operand)) {
    bindings_.BindHloToIrValue(*bitcast, GetBasePointer(*operand));
  }
  return Status::OK();
}

Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
  VLOG(2) << "HandleAddDependency: " << add_dependency->ToString();
  const HloInstruction* operand = add_dependency->operand(0);
  // AddDependency is a no-op, but we still want to bind it to an llvm::Value
  // sometimes, e.g., when its operand is a constant or a bitcast of a
  // constant.
  if (bindings_.BoundToIrValue(*operand)) {
    bindings_.BindHloToIrValue(*add_dependency, GetBasePointer(*operand));
  }
  return Status::OK();
}

Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
  auto operand = get_tuple_element->operand(0);
  CHECK(bindings_.BoundToIrValue(*operand));
  bindings_.BindHloToIrValue(
      *get_tuple_element,
      llvm_ir::EmitGetTupleElement(
          get_tuple_element->shape(), get_tuple_element->tuple_index(),
          // TODO(b/26344050): tighten the alignment here
          // based on the real element type.
          /*alignment=*/1, GetBasePointer(*operand), &b_));
  return Status::OK();
}

Status IrEmitter::HandleSend(HloInstruction*) {
  return Unimplemented("Send is not implemented on GPU");
}

Status IrEmitter::HandleSendDone(HloInstruction*) {
  return Unimplemented("Send-Done is not implemented on GPU");
}

Status IrEmitter::HandleRecv(HloInstruction*) {
  return Unimplemented("Recv is not implemented on GPU");
}

Status IrEmitter::HandleRecvDone(HloInstruction*) {
  return Unimplemented("Recv-done is not implemented on GPU");
}

Status IrEmitter::HandleScatter(HloInstruction*) {
  return Unimplemented("Scatter is not implemented on GPUs.");
}

Status IrEmitter::HandleTuple(HloInstruction* tuple) {
  std::vector<llvm::Value*> base_ptrs;
  for (const HloInstruction* operand : tuple->operands()) {
    base_ptrs.push_back(GetBasePointer(*operand));
  }
  llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &b_);
  return Status::OK();
}

Status IrEmitter::EmitCallToNestedComputation(
    const HloComputation& nested_computation,
    absl::Span<llvm::Value* const> operands, llvm::Value* output) {
  TF_RET_CHECK(nested_computation.num_parameters() > 0);
  llvm::Function*& emitted_function =
      computation_to_ir_function_[&nested_computation];
  if (emitted_function == nullptr) {
    IrEmitterNested ir_emitter_nested(hlo_module_config_, nested_computation,
                                      ir_emitter_context_);
    TF_RETURN_IF_ERROR(
        nested_computation.root_instruction()->Accept(&ir_emitter_nested));
    emitted_function = ir_emitter_nested.GetEmittedFunction();
  }

  std::vector<llvm::Value*> arguments(operands.begin(), operands.end());
  arguments.push_back(output);
  arguments.push_back(bindings_.GetTempBufferBase());
  Call(emitted_function, arguments);

  return Status::OK();
}

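// Tries to lower the two-parameter reduction `computation` to a single
// hardware atomic: an atomic store for a kCopy of the RHS parameter,
// atomicAdd for kAdd on F32 and the atomic integral types, and
// atomicMax/atomicMin for kMaximum/kMinimum on integral types. Returns false
// if no direct lowering applies, in which case the caller falls back to the
// atomicCAS loop emitted below.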
bool IrEmitter::MaybeEmitDirectAtomicOperation(
    const HloComputation& computation, llvm::Value* output_address,
    llvm::Value* source_address) {
  CHECK_EQ(2, computation.num_parameters());

  if (computation.instruction_count() != 3) {
    // We special-case only computations with one computing instruction for
    // now. Such a computation has exactly three instructions given that it has
    // two parameters.
    return false;
  }

  HloOpcode root_opcode = computation.root_instruction()->opcode();
  PrimitiveType element_type =
      computation.root_instruction()->shape().element_type();
  bool is_atomic_integral = element_type == S32 || element_type == U32 ||
                            element_type == S64 || element_type == U64;
  llvm::Value* source = Load(source_address, "source");

  // kCopy of RHS -> atomic store.
  if (root_opcode == HloOpcode::kCopy &&
      (element_type == F32 || is_atomic_integral) &&
      computation.root_instruction()->operand(0)->opcode() ==
          HloOpcode::kParameter &&
      computation.root_instruction()->operand(0)->parameter_number() == 1) {
    llvm::StoreInst* store = Store(source, output_address);
    store->setAtomic(llvm::AtomicOrdering::Unordered);
    // Derive a minimum alignment from the type. The optimizer can increase it
    // later.
    store->setAlignment(ShapeUtil::ByteSizeOfPrimitiveType(element_type));
    return true;
  }

  if (root_opcode == HloOpcode::kAdd) {
    // NVPTX supports atomicAdd on F32 and integer types.
    if (element_type == F32) {
      // F32 + F32
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_atomic_load_add_f32,
                                   {output_address, source},
                                   {output_address->getType()}, &b_);
      return true;
    }
    if (is_atomic_integral) {
      // integral + integral
      AtomicRMW(llvm::AtomicRMWInst::Add, output_address, source,
                llvm::AtomicOrdering::SequentiallyConsistent);
      return true;
    }
  }

  // NVPTX supports atomicMax and atomicMin only on integer types.
  if (root_opcode == HloOpcode::kMaximum && is_atomic_integral) {
    // max(integral, integral)
    auto opcode = primitive_util::IsSignedIntegralType(element_type)
                      ? llvm::AtomicRMWInst::Max
                      : llvm::AtomicRMWInst::UMax;
    AtomicRMW(opcode, output_address, source,
              llvm::AtomicOrdering::SequentiallyConsistent);
    return true;
  }

  if (root_opcode == HloOpcode::kMinimum && is_atomic_integral) {
    // min(integral, integral)
    auto opcode = primitive_util::IsSignedIntegralType(element_type)
                      ? llvm::AtomicRMWInst::Min
                      : llvm::AtomicRMWInst::UMin;
    AtomicRMW(opcode, output_address, source,
              llvm::AtomicOrdering::SequentiallyConsistent);
    return true;
  }

  return false;
}

// Implements atomic binary operations using atomic compare-and-swap
// (atomicCAS) as follows:
// 1. Reads the value from the memory pointed to by output_address and
//    records it as old_output.
// 2. Uses old_output as one of the source operands to perform the binary
//    operation and stores the result in new_output.
// 3. Calls atomicCAS which implements compare-and-swap as an atomic
//    operation. In particular, atomicCAS reads the value from the memory
//    pointed to by output_address, and compares the value with old_output. If
//    the two values are equal, new_output is written to the same memory
//    location and true is returned to indicate that the atomic operation
//    succeeds. Otherwise, the new value read from the memory is returned. In
//    this case, the new value is copied to old_output, and steps 2. and 3. are
//    repeated until atomicCAS succeeds.
//
// On Nvidia GPUs, atomicCAS can only operate on 32 bit and 64 bit integers. If
// the element type of the binary operation is 32 bits or 64 bits, the integer
// type of the same size is used for the atomicCAS operation. On the other
// hand, if the element type is smaller than 32 bits, int32 is used for the
// atomicCAS operation. In this case, atomicCAS reads and writes 32 bit values
// from the memory, which is larger than the memory size required by the
// original atomic binary operation. We mask off the last two bits of the
// output_address and use the result as an address to read the 32 bit values
// from the memory. This can avoid out of bound memory accesses if tensor
// buffers are 4 byte aligned and have a size of 4N, an assumption that the
// runtime can guarantee.
//
// The pseudo code is shown below. Variables *_address are pointers to a memory
// region with a size equal to the size of the atomicCAS operation, with the
// exception that new_output_address is a pointer to a memory region with a
// size equal to the element size of the binary operation.
//
//   element_size = sizeof(element_type);
//   atomic_size = max(32, element_size);
//   cas_new_output_address = alloca(atomic_size);
//   cas_old_output_address = alloca(atomic_size);
//   if (atomic_size != element_size) {
//     atomic_address = output_address & ((int64)(-4));
//     new_output_address = cas_new_output_address + (output_address & 3);
//   } else {
//     atomic_address = output_address;
//     new_output_address = cas_new_output_address;
//   }
//
//   *cas_old_output_address = *atomic_address;
//   do {
//     *cas_new_output_address = *cas_old_output_address;
//     *new_output_address = operation(*new_output_address, *source_address);
//     (*cas_old_output_address, success) =
//         atomicCAS(atomic_address, *cas_old_output_address,
//                   *cas_new_output_address);
//   } while (!success);
//
Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
                                              llvm::Value* output_address,
                                              llvm::Value* source_address) {
  llvm::PointerType* output_address_type =
      llvm::dyn_cast<llvm::PointerType>(output_address->getType());
  CHECK_NE(output_address_type, nullptr);

  // element_type is the data type for the binary operation.
  llvm::Type* element_type = output_address_type->getPointerElementType();
  int element_size = llvm_ir::GetSizeInBits(element_type);
  llvm::Type* element_address_type = element_type->getPointerTo();

  int atomic_size = (element_size < 32) ? 32 : element_size;
  llvm::Type* atomic_type = b_.getIntNTy(atomic_size);
  llvm::Type* atomic_address_type =
      atomic_type->getPointerTo(output_address_type->getPointerAddressSpace());

  // cas_old_output_address and cas_new_output_address point to the scratch
  // memory where we store the old and new values for the repeated atomicCAS
  // operations.
  llvm::Value* cas_old_output_address =
      Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address");
  llvm::Value* cas_new_output_address =
      Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address");

  // Emit preparation code to the preheader.
  llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock();

  llvm::Value* atomic_memory_address;
  // binop_output_address points to the scratch memory that stores the
  // result of the binary operation.
  llvm::Value* binop_output_address;
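  // Illustration of the sub-word case handled below: for a 16-bit element at
  // byte address A, the CAS operates on the aligned 32-bit word starting at
  // A & ~3, while the element's slot inside the scratch word pointed to by
  // cas_new_output_address sits at byte offset A & 3; the CAS then swaps the
  // whole 32-bit word.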
  if (element_size < 32) {
    // Assume the element size is an integer number of bytes.
    CHECK_EQ((element_size % sizeof(char)), 0);
    llvm::Type* address_int_type =
        module_->getDataLayout().getIntPtrType(output_address_type);
    atomic_memory_address = PtrToInt(output_address, address_int_type);
    llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3);
    llvm::Value* offset = And(atomic_memory_address, mask);
    mask = llvm::ConstantInt::get(address_int_type, -4);
    atomic_memory_address = And(atomic_memory_address, mask);
    atomic_memory_address =
        IntToPtr(atomic_memory_address, atomic_address_type);
    binop_output_address =
        Add(PtrToInt(cas_new_output_address, address_int_type), offset);
    binop_output_address =
        IntToPtr(binop_output_address, element_address_type);
  } else {
    atomic_memory_address = BitCast(output_address, atomic_address_type);
    binop_output_address =
        BitCast(cas_new_output_address, element_address_type);
  }

  // Use the value from the memory that atomicCAS operates on to initialize
  // cas_old_output.
  llvm::Value* cas_old_output = Load(atomic_memory_address, "cas_old_output");
  Store(cas_old_output, cas_old_output_address);

  llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock(
      b_.GetInsertPoint(), "atomic_op_loop_exit");
  llvm::BasicBlock* loop_body_bb = llvm::BasicBlock::Create(
      b_.getContext(), "atomic_op_loop_body", b_.GetInsertBlock()->getParent());
  b_.SetInsertPoint(loop_body_bb);
  // Change preheader's successor from loop_exit_bb to loop_body_bb.
  loop_preheader_bb->getTerminator()->setSuccessor(0, loop_body_bb);

  // Emit the body of the loop that repeatedly invokes atomicCAS.
  //
  // Use cas_old_output to initialize cas_new_output.
  cas_old_output = Load(cas_old_output_address, "cas_old_output");
  Store(cas_old_output, cas_new_output_address);
  // Emits code to calculate new_output = operation(old_output, source);
  TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
      computation, {binop_output_address, source_address},
      binop_output_address));

  llvm::Value* cas_new_output = Load(cas_new_output_address, "cas_new_output");

  // Emit code to perform the atomicCAS operation:
  //   (cas_old_output, success) = atomicCAS(memory_address, cas_old_output,
  //                                         cas_new_output);
  llvm::Value* ret_value =
      AtomicCmpXchg(atomic_memory_address, cas_old_output, cas_new_output,
                    llvm::AtomicOrdering::SequentiallyConsistent,
                    llvm::AtomicOrdering::SequentiallyConsistent);

  // Extract the memory value returned from atomicCAS and store it as
  // cas_old_output.
  Store(ExtractValue(ret_value, 0, "cas_old_output"), cas_old_output_address);
  // Extract the success bit returned from atomicCAS and generate a
  // conditional branch on the success bit.
  CondBr(ExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb);

  // Set the insertion point to the exit basic block so that the caller of
  // this method can continue emitting code to the right place.
  SetToFirstInsertPoint(loop_exit_bb, &b_);
  return Status::OK();
}

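// Emits an atomic update that applies `computation` to *output_address and
// *source_address: first tries to map it onto a single hardware atomic via
// MaybeEmitDirectAtomicOperation, and otherwise falls back to the atomicCAS
// loop emitted above.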
Status IrEmitter::EmitAtomicOperationForNestedComputation(
    const HloComputation& computation, llvm::Value* output_address,
    llvm::Value* source_address) {
  if (computation.num_parameters() != 2) {
    // TODO(b/30258929): We only accept binary computations so far.
    return Unimplemented(
        "We only support atomic functions with exactly two parameters, but "
        "computation %s has %d.",
        computation.name(), computation.num_parameters());
  }

  if (MaybeEmitDirectAtomicOperation(computation, output_address,
                                     source_address)) {
    return Status::OK();
  }

  return EmitAtomicOperationUsingCAS(computation, output_address,
                                     source_address);
}

Status IrEmitter::HandleSelect(HloInstruction* select) {
  auto pred = select->operand(0);
  TF_RET_CHECK(pred->shape().element_type() == PRED);
  // We must not call the subclass `DefaultAction` method, lest its
  // `HandleSelect` call `IrEmitter::HandleSelect` and its `DefaultAction`
  // assume no handler has already been called.
  return IrEmitter::DefaultAction(select);
}

Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
  auto pred = tuple_select->operand(0);
  auto on_true = tuple_select->operand(1);
  auto on_false = tuple_select->operand(2);
  TF_RET_CHECK(pred->shape().element_type() == PRED);
  TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()));
  TF_RET_CHECK(tuple_select->shape().IsTuple());
  llvm_ir::EmitTupleSelect(GetIrArray(*tuple_select, *tuple_select),
                           GetIrArray(*pred, *tuple_select),
                           GetBasePointer(*on_true), GetBasePointer(*on_false),
                           &b_);
  return Status::OK();
}

namespace {
llvm::Value* Real(llvm::Value* x, llvm::IRBuilder<>* b) {
  return b->CreateExtractValue(x, {0});
}

llvm::Value* Imag(llvm::Value* x, llvm::IRBuilder<>* b) {
  return b->CreateExtractValue(x, {1});
}

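// Multiplies two complex values represented as {real, imag} aggregates:
//   (a + bi) * (c + di) = (ac - bd) + (ad + bc)i.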
std::pair<llvm::Value*, llvm::Value*> MultiplyComplex(llvm::Value* lhs_value,
                                                      llvm::Value* rhs_value,
                                                      llvm::IRBuilder<>* b) {
  llvm::Value* lhs_real = Real(lhs_value, b);
  llvm::Value* lhs_imag = Imag(lhs_value, b);
  llvm::Value* rhs_real = Real(rhs_value, b);
  llvm::Value* rhs_imag = Imag(rhs_value, b);
  llvm::Value* real_result1 = b->CreateFMul(lhs_real, rhs_real);
  llvm::Value* real_result2 = b->CreateFMul(lhs_imag, rhs_imag);
  llvm::Value* real_result = b->CreateFSub(real_result1, real_result2);
  llvm::Value* imag_result1 = b->CreateFMul(lhs_real, rhs_imag);
  llvm::Value* imag_result2 = b->CreateFMul(lhs_imag, rhs_real);
  llvm::Value* imag_result = b->CreateFAdd(imag_result1, imag_result2);
  return {real_result, imag_result};
}
}  // namespace

Status IrEmitter::HandleDot(HloInstruction* dot) {
  auto lhs_instruction = dot->operand(0);
  auto rhs_instruction = dot->operand(1);
  const llvm_ir::IrArray& target_array = GetIrArray(*dot, *dot);
  const llvm_ir::IrArray& lhs_array = GetIrArray(*lhs_instruction, *dot);
  const llvm_ir::IrArray& rhs_array = GetIrArray(*rhs_instruction, *dot);

  const Shape& lhs_shape = lhs_instruction->shape();
  const Shape& rhs_shape = rhs_instruction->shape();
  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
  CHECK_EQ(dnums.lhs_batch_dimensions_size(),
           dnums.rhs_batch_dimensions_size());

  // TODO(b/110211620): Convert to use i32 index_type when it is possible.
  llvm::Type* index_type = b_.getInt64Ty();
  llvm_ir::IrArray::Index element_index(index_type);
  if (ShapeUtil::IsScalar(lhs_shape) && ShapeUtil::IsScalar(rhs_shape)) {
    // If the operands are scalar, don't emit any loops.
    llvm::Value* lhs_value =
        lhs_array.EmitReadArrayElement(/*index=*/element_index, &b_);
    llvm::Value* rhs_value =
        rhs_array.EmitReadArrayElement(/*index=*/element_index, &b_);
    llvm::Value* result;
    if (ShapeUtil::ElementIsComplex(lhs_shape)) {
      auto value = MultiplyComplex(lhs_value, rhs_value, &b_);
      result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType());
      result = InsertValue(result, value.first, {0});
      result = InsertValue(result, value.second, {1});
    } else if (ShapeUtil::ElementIsFloating(lhs_shape)) {
      result = FMul(lhs_value, rhs_value);
    } else {
      TF_RET_CHECK(ShapeUtil::ElementIsIntegral(lhs_shape));
      result = Mul(lhs_value, rhs_value);
    }
    target_array.EmitWriteArrayElement(/*index=*/element_index, result, &b_);
    return Status::OK();
  }

  // "Scalar dot non-scalar" or "non-scalar dot scalar" is invalid. See
  // the semantics of Dot in the XLA documentation for details.
  TF_RET_CHECK(!ShapeUtil::IsScalar(lhs_shape) &&
               !ShapeUtil::IsScalar(rhs_shape));

  const int64 lhs_reduction_dimension = dnums.lhs_contracting_dimensions(0);
  const int64 rhs_reduction_dimension = dnums.rhs_contracting_dimensions(0);

  // Check that the batch dims don't cover the reduction dimensions.
  for (int64 batch_dim : dnums.lhs_batch_dimensions()) {
    CHECK_NE(lhs_reduction_dimension, batch_dim);
    CHECK_NE(rhs_reduction_dimension, batch_dim);
  }

  // Verify that the reduction dimensions in the two operands have the same
  // size.
  TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) ==
               rhs_shape.dimensions(rhs_reduction_dimension))
      << "lhs_shape.dimensions(" << lhs_reduction_dimension
      << ") = " << lhs_shape.dimensions(lhs_reduction_dimension)
      << ", and rhs_shape.dimensions(" << rhs_reduction_dimension
      << ") = " << rhs_shape.dimensions(rhs_reduction_dimension);

  // Create loop nests which loop through the LHS operand dimensions and the
  // RHS operand dimensions. The reduction dimensions of the LHS and RHS are
  // handled in a separate innermost loop which performs the sum of products.
  llvm_ir::ForLoopNest loop_nest(IrName(dot), &b_);
  std::vector<llvm::Value*> lhs_multi_index =
      loop_nest.EmitOperandArrayLoopNest(
          lhs_array, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs");
  std::vector<llvm::Value*> rhs_multi_index =
      loop_nest.EmitOperandArrayLoopNest(
          rhs_array, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");

  // We don't have to iterate over the batch dimensions in both arrays, so we
  // simplify the loop nest of the rhs.
  for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) {
    DCHECK(absl::c_linear_search(dnums.lhs_batch_dimensions(), i));
    rhs_multi_index[i] = lhs_multi_index[i];
  }

  // Create the reduction loop which does the sum of products reduction.
  std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
      /*start_index=*/0,
      /*end_index=*/lhs_shape.dimensions(lhs_reduction_dimension),
      /*suffix=*/"reduction");

  // The final entry in the rhs and lhs indexes is the indvar of the reduction
  // loop.
  lhs_multi_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue();
  rhs_multi_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue();

  // For computing the sum of products we alloca a single location to store the
  // dot product result as we accumulate it within the reduction loop. After
  // the reduction loop we load the result and store it into the output array.
  llvm::Type* accum_type = target_array.GetElementLlvmType();
  llvm::Value* accum_address = llvm_ir::EmitAllocaAtFunctionEntry(
      accum_type,       // The pointee type of the alloca instruction.
      "accum_address",  // The name of the alloca instruction.
      &b_);

  // Initialize the accumulator in the preheader to zero.
  new llvm::StoreInst(
      llvm::Constant::getNullValue(lhs_array.GetElementLlvmType()),  // init 0
      accum_address,  // The address.
      reduction_loop->GetPreheaderBasicBlock()
          ->getTerminator());  // The instruction this store is inserted before.

  // Emit the body of the reduction loop:
  //   accum = *accum_address
  //   updated_accum = accum + lhs_element * rhs_element
  //   *accum_address = updated_accum
  TF_RET_CHECK(!reduction_loop->GetBodyBasicBlock()->empty());
  b_.SetInsertPoint(
      &*reduction_loop->GetBodyBasicBlock()->getFirstInsertionPt());
  llvm_ir::IrArray::Index lhs_index(lhs_multi_index, lhs_array.GetShape(),
                                    b_.getInt64Ty());
  llvm::Value* lhs_element = lhs_array.EmitReadArrayElement(lhs_index, &b_);
  llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_array.GetShape(),
                                    b_.getInt64Ty());
  llvm::Value* rhs_element = rhs_array.EmitReadArrayElement(rhs_index, &b_);
  llvm::Value* accum = Load(accum_address);
  llvm::Value* updated_accum;
  if (ShapeUtil::ElementIsComplex(lhs_shape)) {
    auto value = MultiplyComplex(lhs_element, rhs_element, &b_);
    llvm::Value* accum_real = Real(accum, &b_);
    llvm::Value* real_sum = FAdd(accum_real, value.first);
    updated_accum = InsertValue(accum, real_sum, {0});
    llvm::Value* accum_imag = Imag(accum, &b_);
    llvm::Value* imag_sum = FAdd(accum_imag, value.second);
    updated_accum = InsertValue(updated_accum, imag_sum, {1});
  } else if (ShapeUtil::ElementIsFloating(lhs_shape)) {
    llvm::Value* product = FMul(lhs_element, rhs_element);
    updated_accum = FAdd(accum, product);
  } else {
    TF_RET_CHECK(ShapeUtil::ElementIsIntegral(lhs_shape));
    llvm::Value* product = Mul(lhs_element, rhs_element);
    updated_accum = Add(accum, product);
  }
  Store(updated_accum, accum_address);

  // After the reduction loop exits, store the accumulator into the target
  // address. The index into the target address is the concatenation of the rhs
  // and lhs indexes with the reduction dimensions removed. The terms from the
  // rhs index are the lower dimensions in the index so we add them first.
  std::vector<llvm::Value*> target_multi_index;
  for (size_t dimension = 0; dimension < lhs_index.size(); ++dimension) {
    if (dimension != lhs_reduction_dimension) {
      target_multi_index.push_back(lhs_index[dimension]);
    }
  }
  // Skip over the batch dimensions to not have them in the index twice.
  for (size_t dimension = dnums.lhs_batch_dimensions_size();
       dimension < rhs_index.size(); ++dimension) {
    if (dimension != rhs_reduction_dimension) {
      target_multi_index.push_back(rhs_index[dimension]);
    }
  }
  SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &b_);
  llvm_ir::IrArray::Index target_index(target_multi_index,
                                       target_array.GetShape(), index_type);
  target_array.EmitWriteArrayElement(
      target_index,
      Load(accum_address),  // The value written to the target array.
      &b_);

  // Set the IR builder insert point to the exit basic block of the outermost
  // loop. This ensures later instructions are inserted after this loop nest.
  b_.SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());

  return Status::OK();
}

Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
  if (ShapeUtil::IsZeroElementArray(convolution->shape())) {
    // Emit no code for an empty output.
    return Status::OK();
  }
  // TODO(b/31409998): Support convolution with dilation.
  return Unimplemented(
      "Hit a case for convolution that is not implemented on GPU.");
}

Status IrEmitter::HandleFft(HloInstruction* fft) {
  if (ShapeUtil::IsZeroElementArray(fft->shape())) {
    // Emit no code for an empty output.
    return Status::OK();
  }
  return Unimplemented("Hit a case for fft that is not implemented on GPU.");
}

Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
  // TODO(b/33011107): Support cross replica sum on GPU.
  return Unimplemented("AllReduce is not implemented on GPU.");
}

Status IrEmitter::HandleParameter(HloInstruction* parameter) {
  return Status::OK();
}

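// Emits a loop over the output elements of the reduce; for each output
// element, an inner loop nest walks the reduction dimensions of the operand
// and folds each input element into an accumulator by calling the reduction
// computation.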
Status IrEmitter::HandleReduce(HloInstruction* reduce) {
  // TODO(b/118332391): Support variadic reduce.
  if (!reduce->shape().IsArray()) {
    return Unimplemented("Variadic reduce is not supported on GPU");
  }
  auto arg = reduce->operand(0);
  auto init_value = reduce->operand(1);
  absl::Span<const int64> dimensions(reduce->dimensions());
  HloComputation* function = reduce->to_apply();
  return EmitTargetElementLoop(
      *reduce,
      [=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
        // Initialize an accumulator with init_value.
        llvm::AllocaInst* accumulator_addr =
            Alloca(llvm_ir::PrimitiveTypeToIrType(
                reduce->shape().element_type(), module_));
        Store(Load(GetBasePointer(*init_value)), accumulator_addr);

        // The enclosing loops go over all the target elements. Now we have to
        // compute the actual target element. For this, we build a new loop
        // nest to iterate over all the reduction dimensions in the argument.
        // AddLoopsForShapeOnDimensions will return an Index where induction
        // Value*s are placed for each dimension in dimensions, and all the
        // rest are nullptrs.
        llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_);
        std::vector<llvm::Value*> input_multi_index =
            loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
                                               "reduction_dim");

        SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);

        // Build a full index for the input argument, using input_multi_index
        // as the base. In input_multi_index only the reduction dimensions are
        // filled in. We fill in the rest of the dimensions with induction
        // Value*s taken from 'index' which iterates over the target array.
        // See the high-level description in the XLA documentation for details.
        llvm_ir::IrArray::Index::const_iterator it = index.begin();

        for (auto& i : input_multi_index) {
          if (i == nullptr) {
            i = *it++;
          }
        }
        CHECK(index.end() == it);

        // Apply the reduction function to the loaded value.
        llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
                                            b_.getInt64Ty());
        llvm::Value* input_address =
            GetIrArray(*arg, *reduce).EmitArrayElementAddress(input_index, &b_);
        TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
            *function, {accumulator_addr, input_address}, accumulator_addr));

        SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
        return Load(accumulator_addr);
      });
}

Status IrEmitter::HandleFusion(HloInstruction* fusion) {
  // kFusion for library calls should be handled by
  // IrEmitterUnnested::HandleFusion.
  CHECK_EQ(HloInstruction::FusionKind::kLoop, fusion->fusion_kind());
  GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
                                          GetNestedComputer());
  FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion),
                               &elemental_emitter);
  TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));

  return EmitTargetElementLoop(*fusion, fused_emitter.GetRootGenerator());
}

Status IrEmitter::HandleCall(HloInstruction* call) {
  std::vector<llvm::Value*> operand_addresses;
  for (HloInstruction* operand : call->operands()) {
    operand_addresses.push_back(GetBasePointer(*operand));
  }
  return EmitCallToNestedComputation(*call->to_apply(), operand_addresses,
                                     GetBasePointer(*call));
}

Status IrEmitter::HandleCustomCall(HloInstruction*) {
  return Unimplemented("custom-call");
}

Status IrEmitter::HandleInfeed(HloInstruction*) {
  // TODO(b/30467474): Implement infeed on GPU.
  return Unimplemented("Infeed is not supported on GPU.");
}

Status IrEmitter::HandleOutfeed(HloInstruction*) {
  // TODO(b/34359662): Implement outfeed on GPU.
  return Unimplemented("Outfeed is not supported on GPU.");
}

Status IrEmitter::HandleBatchNormInference(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormInference directly. It "
      "should be lowered before IR emission to HLO-soup using "
      "BatchNormRewriter or to a cudnn CustomCall using "
      "CudnnBatchNormRewriter.");
}

Status IrEmitter::HandleBatchNormTraining(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormTraining directly. It "
      "should be lowered before IR emission to HLO-soup using "
      "BatchNormRewriter or to a cudnn CustomCall using "
      "CudnnBatchNormRewriter.");
}

Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormGrad directly. It should "
      "be lowered before IR emission to HLO-soup (using BatchNormRewriter) or "
      "to a cudnn CustomCall using CudnnBatchNormRewriter.");
}

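// Evaluates `computation` on scalar values by spilling them into stack slots,
// since nested computations are invoked with pointer arguments, and then
// loading the scalar result back out of the return buffer.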
StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
    const HloComputation& computation,
    absl::Span<llvm::Value* const> parameter_elements) {
  llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
      llvm_ir::PrimitiveTypeToIrType(
          computation.root_instruction()->shape().element_type(), module_),
      "return_buffer", &b_);
  std::vector<llvm::Value*> parameter_buffers;
  for (llvm::Value* parameter_element : parameter_elements) {
    parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
        parameter_element->getType(), "parameter_buffer", &b_));
    Store(parameter_element, parameter_buffers.back());
  }
  TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation,
                                                 parameter_buffers,
                                                 return_buffer));
  return Load(return_buffer);
}

std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(
    const HloInstruction& hlo) {
  std::vector<llvm_ir::IrArray> output_arrays;
  if (hlo.shape().IsTuple()) {
    int64 num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
    output_arrays.reserve(num_outputs);
    for (int64 i = 0; i < num_outputs; ++i) {
      output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
    }
  } else {
    output_arrays.push_back(GetIrArray(hlo, hlo));
  }
  return output_arrays;
}

}  // namespace gpu
}  // namespace xla