/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"

#include "llvm/ADT/StringRef.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
#include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/while_transformer.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ops.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h" 69 #include "tensorflow/core/lib/gtl/array_slice.h" 70 #include "tensorflow/core/platform/logging.h" 71 72 namespace xla { 73 namespace gpu { 74 75 namespace { 76 77 using llvm_ir::IrName; 78 using tensorflow::gtl::ArraySlice; 79 using tensorflow::gtl::nullopt; 80 using tensorflow::gtl::optional; 81 using tensorflow::strings::StrCat; 82 83 // If a dimensions is smaller than this, untiled transposition may be more 84 // efficient. 85 const int64 kMinDimensionToTransposeTiled = 16; 86 87 // Returns true if all paths from `hlo` to `root` contain only tuples. The 88 // result of such an HloInstruction does not need to be materialized, when the 89 // computation can have a hybrid result. 90 bool ReachRootViaOnlyTuples(const HloInstruction& hlo, 91 const HloInstruction& root) { 92 if (hlo.opcode() != HloOpcode::kTuple) { 93 return false; 94 } 95 96 if (&hlo == &root) { 97 return true; 98 } 99 100 for (HloInstruction* user : hlo.users()) { 101 if (!ReachRootViaOnlyTuples(*user, root)) { 102 return false; 103 } 104 } 105 106 return true; 107 } 108 109 // If `hlo` is a Transpose, returns its operand; otherwise returns `hlo` itself. 110 const HloInstruction* StripTranspose(const HloInstruction& hlo) { 111 if (hlo.IsRank2Transpose()) { 112 return hlo.operand(0); 113 } 114 return &hlo; 115 } 116 117 // Updates the launch dimensions in "thunk" and annotate the launch dimensions 118 // of the corresponding IR kernel in "llvm_module". 119 // Precondition: "thunk" must be a KernelThunk. 120 void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk, 121 llvm::Module* llvm_module) { 122 CHECK(Thunk::Kind::kKernel == thunk->kind()); 123 KernelThunk* kernel_thunk = static_cast<KernelThunk*>(thunk); 124 kernel_thunk->SetLaunchDimensions(launch_dims); 125 126 // Add __launch_bounds__ to metadata. This limits registers per thread to 127 // avoid out-of-resources launching errors. 128 llvm::NamedMDNode* nvvm_annotations_node = 129 llvm_module->getOrInsertNamedMetadata("nvvm.annotations"); 130 llvm::Function* ir_kernel = 131 llvm_module->getFunction(kernel_thunk->kernel_name().c_str()); 132 llvm::LLVMContext& llvm_context = llvm_module->getContext(); 133 llvm::ConstantInt* threads_per_block_ir_value = llvm::ConstantInt::get( 134 llvm::IntegerType::get(llvm_context, /*NumBits=*/32), 135 launch_dims.threads_per_block()); 136 // Our launch bounds are exact, so we can specify them as reqntidx rather than 137 // maxntidx. 138 nvvm_annotations_node->addOperand(llvm::MDNode::get( 139 llvm_context, 140 {llvm::ConstantAsMetadata::get(ir_kernel), 141 llvm::MDString::get(llvm_context, "reqntidx"), 142 llvm::ConstantAsMetadata::get(threads_per_block_ir_value)})); 143 } 144 145 // Tries to get a Slice for the given instruction at the given index, but 146 // returns nullopt if we might not know the slice's address at runtime without 147 // dereferencing a containing tuple. 148 // 149 // In particular, when XLA accepts a parameter of tuple type, the caller has the 150 // option of telling XLA what are the values inside of the tuple, or just giving 151 // XLA a pointer to the top-level tuple and letting us chase the pointers on the 152 // GPU. We therefore cannot rely having these pointers to parameter sub-buffers 153 // being present when we run the program. 
// Tries to get a Slice for the given instruction at the given index, but
// returns nullopt if we might not know the slice's address at runtime without
// dereferencing a containing tuple.
//
// In particular, when XLA accepts a parameter of tuple type, the caller has
// the option of telling XLA what the values inside the tuple are, or just
// giving XLA a pointer to the top-level tuple and letting us chase the
// pointers on the GPU. We therefore cannot rely on these pointers to
// parameter sub-buffers being present when we run the program.
optional<BufferAllocation::Slice> GetKnownAtRuntimeSlice(
    const HloInstruction* instr, const ShapeIndex& index,
    const BufferAssignment& buffer_assn) {
  auto maybe_slice = buffer_assn.GetUniqueSlice(instr, index);
  if (!maybe_slice.ok()) {
    return nullopt;
  }
  // BufferAllocation gives a slice and alloc to every buffer accessed by XLA,
  // but we don't necessarily know the runtime address of sub-buffers of input
  // parameters.
  const BufferAllocation::Slice& slice = maybe_slice.ValueOrDie();
  const BufferAllocation* alloc = slice.allocation();
  if (alloc->IsInputOrOutput() && !alloc->maybe_live_out() &&
      !alloc->param_shape_index().empty()) {
    return nullopt;
  }

  // Otherwise, we will know the address of this slice at runtime without
  // having to dereference a tuple.
  return slice;
}

}  // namespace

IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
                                     const HloComputation* hlo_computation,
                                     IrEmitterContext* ir_emitter_context)
    : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false),
      hlo_computation_(hlo_computation) {
  // Initialize thunk_sequence_ to an empty list of thunks.
  thunk_sequence_.reset(new ThunkSequence());
}

Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) {
  bindings_.UnbindAllLocalIrValues();
  return DfsHloVisitor::Postprocess(hlo);
}

namespace {
bool ImplementedAsHostToDeviceMemcpy(const BufferAssignment& buffer_assignment,
                                     const HloInstruction& hlo) {
  // `hlo` needs to satisfy the following conditions to be implemented as a
  // host-to-device cuMemcpy.
  //
  // 1. `hlo` is a kCopy instruction.
  // 2. `hlo`'s only operand is a kConstant instruction.
  // 3. `hlo` and its operand have the same shape (thus the same layout too).
  // 4. The address of `hlo`'s buffer is known at runtime (without
  //    dereferencing pointers in a tuple).
  return hlo.opcode() == HloOpcode::kCopy &&
         hlo.operand(0)->opcode() == HloOpcode::kConstant &&
         ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) &&
         GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value();
}

bool ImplementedAsDeviceToDeviceMemcpy(
    const BufferAssignment& buffer_assignment, const HloInstruction& hlo) {
  // `hlo` needs to satisfy the following conditions to be implemented as a
  // device-to-device cuMemcpy.
  //
  // 1. `hlo` is a kCopy instruction.
  // 2. `hlo` and its operand have the same shape (thus the same layout too).
  // 3. Both `hlo` and its operand have buffer slices that are known at
  //    runtime (constants, for instance, do not), which means the source
  //    buffer also resides on the device.
  return hlo.opcode() == HloOpcode::kCopy &&
         ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) &&
         GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value() &&
         GetKnownAtRuntimeSlice(hlo.operand(0), {}, buffer_assignment)
             .has_value();
}
}  // namespace
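// For example (illustrative HLO, hypothetical names): in
//
//   %constant.0 = f32[256]{0} constant({...})
//   %copy.0 = f32[256]{0} copy(%constant.0)
//
// %copy.0 meets the host-to-device conditions above, while a copy whose
// operand is a non-constant value with a known device address meets the
// device-to-device conditions.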
llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
    const HloInstruction& inst,
    tensorflow::gtl::ArraySlice<const BufferAllocation*> args) {
  // Compute the kernel name. The opcode string may contain "-" which cannot be
  // in a PTX function name, so sanitize the name before uniquifying it.
  string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName(
      llvm_ir::SanitizeFunctionName(inst.name()));

  // Create the kernel and add it to the module.
  llvm::Module* module = ir_emitter_context_->llvm_module();
  llvm::LLVMContext& context = module->getContext();
  llvm::FunctionType* kernel_type = llvm::FunctionType::get(
      /*Result=*/llvm::Type::getVoidTy(context),
      std::vector<llvm::Type*>(args.size(), ir_builder_.getInt8PtrTy()),
      /*isVarArg=*/false);
  llvm::Function* kernel =
      llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage,
                             kernel_name.c_str(), module);

  // Add dereferenceable and alignment information to each of the kernel's
  // parameters.
  auto arg_it = kernel->arg_begin();
  for (size_t arg_no = 0; arg_no < args.size(); ++arg_no) {
    const BufferAllocation* alloc = args[arg_no];
    llvm::Argument* fn_arg = &*arg_it;
    ++arg_it;

    kernel->addDereferenceableAttr(arg_no + 1, alloc->size());
    kernel->addParamAttr(
        arg_no, llvm::Attribute::get(context, llvm::Attribute::Alignment,
                                     kCudaMallocAlignBytes));

    if (alloc->IsPreallocatedTempBuffer()) {
      fn_arg->setName("temp_buf");
    } else {
      fn_arg->setName(llvm_ir::AsStringRef(StrCat("alloc", alloc->index())));
    }
  }

  // TODO(b/65380986): Investigate if adding fast math flags for generated
  // kernels makes sense.

  // Add the declaration of this kernel to llvm.nvvm.annotations so that NVPTX
  // treats it as a CUDA kernel.
  llvm::NamedMDNode* nvvm_annotations_node =
      module->getOrInsertNamedMetadata("nvvm.annotations");
  nvvm_annotations_node->addOperand(llvm::MDNode::get(
      context, {llvm::ConstantAsMetadata::get(kernel),
                llvm::MDString::get(context, "kernel"),
                llvm::ConstantAsMetadata::get(ir_builder_.getInt32(1))}));

  // Update the insert point to the entry basic block.
  llvm::BasicBlock* entry_bb =
      llvm::BasicBlock::Create(context, /*Name=*/"entry", /*Parent=*/kernel);

  // Emit a "return void" at entry_bb's end, and set the insert point before
  // that return instruction.
  ir_builder_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb));

  return kernel;
}
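// For illustration, a sketch of the prototype this produces for a kernel with
// two buffer arguments (hypothetical names and sizes; kCudaMallocAlignBytes
// assumed to be 256):
//
//   define void @fusion(i8* dereferenceable(4096) align 256 %alloc0,
//                       i8* dereferenceable(1024) align 256 %temp_buf) {
//   entry:
//     ret void
//   }
//
// The kernel body is then emitted at the insert point, just before the ret.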
Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
  thunk_sequence_->emplace_back(BuildKernelThunk(hlo));
  return IrEmitter::DefaultAction(hlo);
}

Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
  if (dnums.lhs_batch_dimensions_size() > 0 ||
      dnums.rhs_batch_dimensions_size() > 0) {
    return Unimplemented("Dot with batch dimensions not implemented.");
  }
  if (ImplementedAsGemm(*dot)) {
    thunk_sequence_->emplace_back(BuildGemmThunk(dot));
    return Status::OK();
  }
  thunk_sequence_->emplace_back(BuildKernelThunk(dot));
  return IrEmitter::HandleDot(dot);
}

Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
  thunk_sequence_->emplace_back(BuildConditionalThunk(conditional));
  return Status::OK();
}

Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) {
  thunk_sequence_->emplace_back(BuildKernelThunk(convolution));
  return IrEmitter::HandleConvolution(convolution);
}

Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
  // A CustomCall on the GPU backend can either be a custom-call to a
  // user-supplied kernel, or a call into a library like cudnn.

  // Lower custom-calls to cudnn batchnorm ops to specialized thunks. It's
  // part of the contract of these cudnn batchnorm calls that the epsilon and
  // feature_index operands be constants.
  if (custom_call->custom_call_target() ==
      kCudnnBatchNormForwardInferenceCallTarget) {
    const HloInstruction* epsilon = custom_call->operand(5);
    CHECK(epsilon->IsConstant());
    float epsilon_value = epsilon->literal().Get<float>({});

    const HloInstruction* feature_index = custom_call->operand(6);
    CHECK(feature_index->IsConstant());
    int64 feature_index_value = feature_index->literal().Get<int64>({});

    thunk_sequence_->emplace_back(
        MakeUnique<CudnnBatchNormForwardInferenceThunk>(
            /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
            /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
            /*offset=*/GetAllocationSlice(*custom_call->operand(2)),
            /*mean=*/GetAllocationSlice(*custom_call->operand(3)),
            /*variance=*/GetAllocationSlice(*custom_call->operand(4)),
            /*epsilon=*/epsilon_value,
            /*feature_index=*/feature_index_value,
            /*output=*/GetAllocationSlice(*custom_call),
            /*hlo=*/custom_call));
    return Status::OK();
  }

  if (custom_call->custom_call_target() ==
      kCudnnBatchNormForwardTrainingCallTarget) {
    const HloInstruction* epsilon = custom_call->operand(3);
    CHECK(epsilon->IsConstant());
    float epsilon_value = epsilon->literal().Get<float>({});

    const HloInstruction* feature_index = custom_call->operand(4);
    CHECK(feature_index->IsConstant());
    int64 feature_index_value = feature_index->literal().Get<int64>({});

    // BatchNormTraining returns a tuple of three elements: data, calculated
    // mean, and calculated 1/sqrt(variance + epsilon).
    const auto& assn = ir_emitter_context_->buffer_assignment();
    auto output_data = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
    auto output_mean = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
    auto output_inv_stddev = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie();
    thunk_sequence_->emplace_back(
        MakeUnique<CudnnBatchNormForwardTrainingThunk>(
            /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
            /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
            /*offset=*/GetAllocationSlice(*custom_call->operand(2)),
            /*epsilon=*/epsilon_value,
            /*feature_index=*/feature_index_value,
            /*output_data=*/output_data,
            /*output_mean=*/output_mean,
            /*output_inv_stddev=*/output_inv_stddev,
            /*output_tuple=*/GetAllocationSlice(*custom_call),
            /*hlo=*/custom_call));
    return Status::OK();
  }
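  // For example (illustrative shapes): a batchnorm-training custom-call on a
  // f32[16,32,8,8] operand with feature_index = 1 produces the tuple
  // (f32[16,32,8,8], f32[32], f32[32]): the normalized data plus the
  // per-feature mean and 1/sqrt(variance + epsilon).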
  if (custom_call->custom_call_target() == kCudnnBatchNormBackwardCallTarget) {
    const HloInstruction* epsilon = custom_call->operand(5);
    CHECK(epsilon->IsConstant());
    float epsilon_value = epsilon->literal().Get<float>({});

    const HloInstruction* feature_index = custom_call->operand(6);
    CHECK(feature_index->IsConstant());
    int64 feature_index_value = feature_index->literal().Get<int64>({});

    // BatchNormGrad returns a tuple of three elements: grad_data, grad_scale,
    // grad_offset.
    const auto& assn = ir_emitter_context_->buffer_assignment();
    auto output_grad_data = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
    auto output_grad_scale = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
    auto output_grad_offset =
        assn.GetUniqueSlice(custom_call, {2}).ValueOrDie();
    thunk_sequence_->emplace_back(MakeUnique<CudnnBatchNormBackwardThunk>(
        /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
        /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
        /*mean=*/GetAllocationSlice(*custom_call->operand(2)),
        /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)),
        /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)),
        /*epsilon=*/epsilon_value,
        /*feature_index=*/feature_index_value,
        /*output_grad_data=*/output_grad_data,
        /*output_grad_scale=*/output_grad_scale,
        /*output_grad_offset=*/output_grad_offset,
        /*output_tuple=*/GetAllocationSlice(*custom_call),
        /*hlo=*/custom_call));
    return Status::OK();
  }

  if (IsCustomCallToDnnConvolution(*custom_call)) {
    const auto& assn = ir_emitter_context_->buffer_assignment();
    const auto& lhs_shape = custom_call->operand(0)->shape();
    const auto& rhs_shape = custom_call->operand(1)->shape();
    const auto& conv_result_shape = custom_call->shape().tuple_shapes(0);
    auto lhs_slice = GetAllocationSlice(*custom_call->operand(0));
    auto rhs_slice = GetAllocationSlice(*custom_call->operand(1));
    auto tuple_result_slice = GetAllocationSlice(*custom_call);
    auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
    auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();

    const HloInstruction* algorithm_inst = custom_call->operand(2);
    CHECK(algorithm_inst->IsConstant()) << algorithm_inst->ToString();
    int64 algorithm = algorithm_inst->literal().Get<int64>({});

    const HloInstruction* tensor_ops_enabled_inst = custom_call->operand(3);
    CHECK(tensor_ops_enabled_inst->IsConstant())
        << tensor_ops_enabled_inst->ToString();
    bool tensor_ops_enabled = tensor_ops_enabled_inst->literal().Get<bool>({});
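    // The custom-call target determines how (lhs, rhs, result) map onto
    // cudnn's (input, filter, output) triple; for reference:
    //
    //   forward:         input = lhs,    filter = rhs,    output = result
    //   backward-input:  input = result, filter = rhs,    output = lhs
    //   backward-filter: input = lhs,    filter = result, output = rhs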
    const auto& target = custom_call->custom_call_target();
    std::unique_ptr<ConvolutionThunk> thunk;
    if (target == kCudnnConvForwardCallTarget) {
      thunk = MakeUnique<ConvolutionThunk>(
          CudnnConvKind::kForward,
          /*input_buffer=*/lhs_slice,
          /*filter_buffer=*/rhs_slice,
          /*output_buffer=*/conv_result_slice,
          /*tuple_result_buffer=*/tuple_result_slice,
          /*scratch_buffer=*/scratch_slice,
          /*input_shape=*/lhs_shape,
          /*filter_shape=*/rhs_shape,
          /*output_shape=*/conv_result_shape,  //
          custom_call->window(), custom_call->convolution_dimension_numbers(),
          algorithm, tensor_ops_enabled, custom_call);
    } else if (target == kCudnnConvBackwardInputCallTarget) {
      thunk = MakeUnique<ConvolutionThunk>(
          CudnnConvKind::kBackwardInput,
          /*input_buffer=*/conv_result_slice,
          /*filter_buffer=*/rhs_slice,
          /*output_buffer=*/lhs_slice,
          /*tuple_result_buffer=*/tuple_result_slice,
          /*scratch_buffer=*/scratch_slice,
          /*input_shape=*/conv_result_shape,
          /*filter_shape=*/rhs_shape,
          /*output_shape=*/lhs_shape,  //
          custom_call->window(), custom_call->convolution_dimension_numbers(),
          algorithm, tensor_ops_enabled, custom_call);
    } else if (target == kCudnnConvBackwardFilterCallTarget) {
      thunk = MakeUnique<ConvolutionThunk>(
          CudnnConvKind::kBackwardFilter,
          /*input_buffer=*/lhs_slice,
          /*filter_buffer=*/conv_result_slice,
          /*output_buffer=*/rhs_slice,
          /*tuple_result_buffer=*/tuple_result_slice,
          /*scratch_buffer=*/scratch_slice,
          /*input_shape=*/lhs_shape,
          /*filter_shape=*/conv_result_shape,
          /*output_shape=*/rhs_shape,  //
          custom_call->window(), custom_call->convolution_dimension_numbers(),
          algorithm, tensor_ops_enabled, custom_call);
    } else {
      LOG(FATAL) << "Unexpected custom call target: "
                 << custom_call->custom_call_target();
    }

    thunk_sequence_->emplace_back(std::move(thunk));
    return Status::OK();
  }

  return IrEmitter::HandleCustomCall(custom_call);
}

Status IrEmitterUnnested::HandleFft(HloInstruction* fft) {
  TF_RET_CHECK(
      LayoutUtil::IsMonotonicWithDim0Major(fft->operand(0)->shape().layout()));
  TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout()));
  thunk_sequence_->emplace_back(BuildFftThunk(fft));
  return Status::OK();
}

Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
  HloInstruction* root = fusion->fused_expression_root();
  // HandleFusion specializes reduction from a multi-dimensional array to a 1D
  // array. The specialized version requires an initializer thunk that
  // initializes the output array to the initial value of the reduce.
  if (HloInstruction::FusionKind::kInput == fusion->fusion_kind()) {
    switch (root->opcode()) {
      case HloOpcode::kReduce: {
        VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString();
        std::vector<std::unique_ptr<Thunk>> thunks;
        thunks.emplace_back(BuildKernelThunk(fusion));
        TF_RETURN_IF_ERROR(EmitInitializer(
            fusion, static_cast<KernelThunk*>(thunks.back().get())));
        bindings_.UnbindAllLocalIrValues();
        thunks.emplace_back(BuildKernelThunk(fusion));
        thunk_sequence_->emplace_back(
            MakeUnique<SequentialThunk>(std::move(thunks), fusion));
        std::vector<llvm_ir::IrArray> parameter_arrays;
        for (HloInstruction* operand : fusion->operands()) {
          parameter_arrays.push_back(GetIrArray(*operand, *fusion));
        }
        GpuElementalIrEmitter elemental_emitter(
            hlo_module_config_, ir_emitter_context_->llvm_module(),
            &ir_builder_, GetNestedComputer());
        FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
        TF_RETURN_IF_ERROR(root->Accept(&fused_emitter));

        Shape input_shape = root->operand(0)->shape();
        // EmitReductionToVector requires the input shape to have a layout, but
        // fused instructions don't have one. So we determine its layout from
        // the fusion's operands. The choice of layout only affects
        // performance, not correctness.
        auto choose_input_layout = [](
            tensorflow::gtl::ArraySlice<const HloInstruction*> operands,
            Shape* input_shape) -> Status {
          // Prefer the layout of an operand whose shape is compatible with
          // input_shape.
          for (const HloInstruction* operand : operands) {
            if (ShapeUtil::Compatible(*input_shape, operand->shape())) {
              return LayoutUtil::CopyLayoutBetweenShapes(operand->shape(),
                                                         input_shape);
            }
          }
          // If no operand has a compatible shape, prefer an operand that has
          // at least the same rank.
          for (const HloInstruction* operand : operands) {
            if (ShapeUtil::Rank(*input_shape) ==
                ShapeUtil::Rank(operand->shape())) {
              // Do not use CopyLayoutBetweenShapes because input_shape and
              // operand->shape() may be incompatible.
              *input_shape->mutable_layout() = operand->shape().layout();
              return Status::OK();
            }
          }
          // When all of the above fail, which is rare, set the default layout.
          LayoutUtil::SetToDefaultLayout(input_shape);
          return Status::OK();
        };
        TF_RETURN_IF_ERROR(
            choose_input_layout(fusion->operands(), &input_shape));

        return EmitReductionToVector(
            root, input_shape, fused_emitter.GetGenerator(root->operand(0)),
            fused_emitter.GetGenerator(root->operand(1)), root->dimensions(),
            root->to_apply());
      }
      default:
        LOG(FATAL) << "Bad opcode for input fusion: "
                   << fusion->fused_expression_root()->opcode();
    }
  } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(
                 fusion, ir_emitter_context_->buffer_assignment())) {
    // Fusion node with dynamic-update-slice as the root where the op's input
    // (i.e. array to update) shares the same slice as its output. In this case
    // we have a special algorithm that modifies the output in place without
    // touching the un-updated elements.

    // Set up kernel thunk and fused ir emitter.
    thunk_sequence_->emplace_back(BuildKernelThunk(fusion));
    std::vector<llvm_ir::IrArray> operand_arrays;
    for (HloInstruction* operand : fusion->operands()) {
      operand_arrays.push_back(GetIrArray(*operand, *fusion));
    }
    GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
                                            ir_emitter_context_->llvm_module(),
                                            &ir_builder_, GetNestedComputer());

    // Shape of the dynamic-update-slice's "update" operand.
    Shape update_shape = root->operand(1)->shape();

    // Array to write into. Because this is an in-place operation, this is the
    // same as operand 0's array.
    llvm_ir::IrArray output_array = GetIrArray(*fusion, *fusion);

    LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
        update_shape, ir_emitter_context_->device_description());
    CHECK(Thunk::Kind::kKernel == LastThunk()->kind());
    UpdateLaunchDimensions(launch_dimensions,
                           static_cast<KernelThunk*>(LastThunk()),
                           ir_emitter_context_->llvm_module());

    return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace(
        fusion, operand_arrays, output_array, &elemental_emitter,
        launch_dimensions, &ir_builder_);
  }
  if (ImplementedAsGemm(*fusion)) {
    thunk_sequence_->emplace_back(BuildGemmThunk(fusion));
    return Status::OK();
  }
  thunk_sequence_->emplace_back(BuildKernelThunk(fusion));
  return IrEmitter::HandleFusion(fusion);
}

namespace {

// Returns the indices of the first elements of all consecutive subarrays of
// the given array. For example:
// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
std::vector<size_t> ConsecutiveSegments(tensorflow::gtl::ArraySlice<int64> xs) {
  std::vector<size_t> is = {0};
  for (size_t i = 1; i < xs.size(); ++i) {
    if (1 != xs[i] - xs[i - 1]) {
      is.push_back(i);
    }
  }
  return is;
}

// Merges the sequences of dimensions of the given shape which start at the
// given indices `segs`.
Shape MergeDimensions(tensorflow::gtl::ArraySlice<size_t> segs,
                      const Shape& shape) {
  std::vector<int64> dimensions;
  for (size_t i = 1; i <= segs.size(); ++i) {
    dimensions.push_back(std::accumulate(
        shape.dimensions().begin() + segs[i - 1],
        shape.dimensions().begin() +
            (segs.size() == i ? shape.dimensions().size() : segs[i]),
        1, std::multiplies<int64>()));
  }
  return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
                                                  dimensions);
}
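// For example: ConsecutiveSegments({0, 1, 2, 5, 8, 9}) = {0, 3, 4}, and
// merging a shape f32[2,3,4,5] with segs = {0, 2} multiplies dimensions
// {0,1} and {2,3} together, giving f32[6,20].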
// Returns whether the transformation from shape `a` to shape `b` is a 0-2-1
// transpose, and if so, the normalized and rank-reduced shapes. The shapes
// must have the same dimensions, so this considers layout only.
//
// This function recognizes higher-rank transposes which are elementwise
// equivalent to a 0-2-1 transpose.
std::tuple<bool, Shape, Shape> IsTranspose021(const Shape& a, const Shape& b) {
  CHECK(ShapeUtil::Compatible(a, b));
  std::vector<int64> perm(a.dimensions().size());
  {
    auto layout_a_orig = LayoutUtil::MinorToMajor(a);
    std::vector<int64> layout_a(layout_a_orig.rbegin(), layout_a_orig.rend());
    auto layout_b_orig = LayoutUtil::MinorToMajor(b);
    std::vector<int64> layout_b(layout_b_orig.rbegin(), layout_b_orig.rend());
    for (size_t i = 0; i < perm.size(); ++i) {
      perm[i] = PositionInContainer(layout_b, layout_a[i]);
    }
  }
  auto segs = ConsecutiveSegments(perm);
  Shape norm_a =
      ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
  Shape norm_b =
      ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(b);
  if (3 == segs.size() && 0 == perm[0]) {
    Shape reduced_a = MergeDimensions(segs, norm_a);
    Shape reduced_b = ShapeUtil::MakeShapeWithDescendingLayout(
        b.element_type(),
        Permute({0, 2, 1}, AsInt64Slice(reduced_a.dimensions())));
    return std::make_tuple(true, reduced_a, reduced_b);
  }
  return std::make_tuple(false, ShapeUtil::MakeNil(), ShapeUtil::MakeNil());
}

// Returns whether the given shapes are potentially those of a 0-2-1
// transpose. As 0-2-1 is a self-inverse permutation, which shape is the input
// and which is the output is arbitrary.
bool AreShapesForTranspose021(const Shape& a, const Shape& b) {
  return 3 == b.dimensions().size() &&
         ShapeUtil::Compatible(
             ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a),
             ShapeUtil::PermuteDimensions(
                 {0, 2, 1},
                 ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
                     b)));
}
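// For example: a = f32[8,16,32]{2,1,0} and b = f32[8,16,32]{1,2,0} form a
// 0-2-1 transpose: IsTranspose021 computes the permutation {0, 2, 1} and
// returns the reduced shapes f32[8,16,32] and f32[8,32,16].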
// Emits a tiled 0-2-1 transpose, assuming both input and output are laid out
// from major to minor. The x- and y- dimensions are tiled in square tiles of
// edge length `tile_size`. Each thread block of `tile_size` x `num_rows`
// threads transposes one tile: each thread copies a row from the input to a
// shared memory tile, then copies a column from the shared memory tile to the
// output.
//
// `tile_size` should usually be the same as the warp size.
//
// Returns the number of tiles, which equals the number of thread blocks
// needed.
//
// TODO(b/33320379): Here each block transposes 1 tile. It may be more
// efficient to launch fewer blocks so each transposes many tiles, and in any
// case, the number of blocks we can launch is limited.
//
// This is the same algorithm as this CUDA kernel:
// https://github.com/tensorflow/tensorflow/blob/d2693c8a70567cc78b2e8a9ac8020d321620ca83/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc#L189
int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output,
                            const int64 tile_size, const int64 num_rows,
                            llvm::IRBuilder<>* builder) {
  // Adds `addend` to the given `dim` of `index`.
  auto offset_dim = [builder](llvm_ir::IrArray::Index index,
                              llvm::Value* addend, int64 dim) {
    index[dim] = builder->CreateAdd(index[dim], addend);
    return index;
  };

  CHECK(AreShapesForTranspose021(input.GetShape(), output.GetShape()));

  Shape input_shape =
      ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
          input.GetShape());
  Shape output_shape =
      ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
          output.GetShape());
  input = input.CastToShape(input_shape, builder);
  output = output.CastToShape(output_shape, builder);

  llvm::Type* tile_type = llvm::ArrayType::get(
      llvm::ArrayType::get(input.GetElementLlvmType(), tile_size),
      // One extra here to avoid shared memory bank conflicts.
      tile_size + 1);
  auto* tile = new llvm::GlobalVariable(
      *builder->GetInsertBlock()->getParent()->getParent(), tile_type,
      /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
      llvm::UndefValue::get(tile_type), "tile", nullptr,
      llvm::GlobalValue::NotThreadLocal,
      /*AddressSpace=*/3 /* GPU shared memory */);

  // let x = threadIdx.x
  llvm::Value* x = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder);
  llvm_ir::AddRangeMetadata(0, num_rows * tile_size,
                            static_cast<llvm::Instruction*>(x));
  x = builder->CreateIntCast(x, builder->getInt64Ty(), /*isSigned=*/true,
                             "thread.id.x");

  // Compute the logical thread ids:
  // logical_x = x % tile_size
  auto logical_x = builder->CreateURem(x, builder->getInt64(tile_size));

  // logical_y = x / tile_size
  auto logical_y = builder->CreateUDiv(x, builder->getInt64(tile_size));
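  // For example: with tile_size = 32 and num_rows = 8, each block has
  // 32 * 8 = 256 threads, and thread x = 100 gets logical_x = 100 % 32 = 4
  // and logical_y = 100 / 32 = 3.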
  // `emit_cp_tile` emits code equivalent to the following pseudocode:
  //
  // if (tile_size == tile_width && tile_size == tile_height) {
  //   unroll for (i in range(0, tile_size, num_rows)) {
  //     emit_cp_element(index + {0, i, 0}, i + logical_y);
  //   }
  // } else if (x < tile_width) {
  //   tile_height_upperbound = ceil(tile_height / num_rows) * num_rows;
  //   for (i in range(0, tile_height_upperbound, num_rows)) {
  //     y_loc = i + logical_y;
  //     if (y_loc < tile_height)
  //       emit_cp_element(index + {0, i, 0}, y_loc);
  //   }
  // }
  //
  // We use this to emit both the copy from input to tile and the copy from
  // tile to output.
  //
  // `index` is the origin of the row or column in the input or output array.
  //
  // `emit_cp_element(index, y)` emits code to copy a single element between
  // the tile and the input or output array, where `y` is the `y`-position in
  // the tile. Whether `y` indexes a row or a column depends on whether we are
  // copying from the input or to the output, and `index` is the index into
  // the input or output array.
  auto emit_cp_tile = [builder, tile_size, &offset_dim, num_rows, logical_x,
                       logical_y](
      std::function<void(const llvm_ir::IrArray::Index&, llvm::Value*)>
          emit_cp_element,
      llvm::Value* tile_width, llvm::Value* tile_height,
      const llvm_ir::IrArray::Index& index, const string& loop_name) {
    llvm_ir::LlvmIfData if_not_last_row = llvm_ir::EmitIfThenElse(
        builder->CreateAnd(
            builder->CreateICmpEQ(builder->getInt64(tile_size), tile_width),
            builder->CreateICmpEQ(builder->getInt64(tile_size), tile_height)),
        "not_last_row", builder);
    builder->SetInsertPoint(if_not_last_row.true_block->getTerminator());
    for (int64 i = 0; i < tile_size; i += num_rows) {
      auto source_idx = offset_dim(index, builder->getInt64(i), /*dim=*/1);
      auto y_loc = builder->CreateAdd(builder->getInt64(i), logical_y);
      emit_cp_element(source_idx, y_loc);
    }
    builder->SetInsertPoint(if_not_last_row.false_block->getTerminator());
    llvm_ir::LlvmIfData if_in_tile = llvm_ir::EmitIfThenElse(
        builder->CreateICmpULT(logical_x, tile_width), "x_in_tile", builder);
    builder->SetInsertPoint(if_in_tile.true_block->getTerminator());

    // tile_height_upper_bound = ceil(tile_height / num_rows) * num_rows
    auto tile_height_upper_bound = builder->CreateMul(
        builder->CreateUDiv(
            builder->CreateAdd(tile_height, builder->getInt64(num_rows - 1)),
            builder->getInt64(num_rows)),
        builder->getInt64(num_rows));

    auto loop = llvm_ir::ForLoop::EmitForLoop(
        loop_name, builder->getInt64(0), tile_height_upper_bound,
        builder->getInt64(num_rows), builder);
    llvm_ir::SetToFirstInsertPoint(loop->GetHeaderBasicBlock(), builder);
    builder->SetInsertPoint(loop->GetBodyBasicBlock()->getTerminator());

    auto y_loc = builder->CreateAdd(loop->GetIndVarValue(), logical_y);
    auto if_y_in_tile = llvm_ir::EmitIfThenElse(
        builder->CreateICmpULT(y_loc, tile_height), "y_in_tile", builder);
    builder->SetInsertPoint(if_y_in_tile.true_block->getTerminator());

    emit_cp_element(offset_dim(index, loop->GetIndVarValue(), /*dim=*/1),
                    y_loc);
    builder->SetInsertPoint(if_not_last_row.after_block->getTerminator());
  };

  auto input_dims_in_tiles = input_shape.dimensions();
  // Unpermuted dimensions are untiled.
  for (int i = 1; i < 3; ++i) {
    input_dims_in_tiles[i] =
        CeilOfRatio<int64>(input_dims_in_tiles[i], tile_size);
  }
  int64 num_tiles =
      std::accumulate(input_dims_in_tiles.begin(), input_dims_in_tiles.end(),
                      1, std::multiplies<int64>());
  const llvm_ir::IrArray::Index input_tile_index(
      /*linear=*/builder->CreateIntCast(
          llvm_ir::AddRangeMetadata(
              0, num_tiles,
              static_cast<llvm::Instruction*>(llvm_ir::EmitCallToIntrinsic(
                  llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {},
                  builder))),
          builder->getInt64Ty(), /*isSigned=*/true, "block.id.x"),
      ShapeUtil::MakeShapeWithDescendingLayout(
          PRED /*arbitrary*/, AsInt64Slice(input_dims_in_tiles)),
      builder);
  const llvm_ir::IrArray::Index input_tile_origin = ({
    llvm_ir::IrArray::Index index = input_tile_index;
    for (int i = 1; i < 3; ++i) {
      index[i] = builder->CreateMul(index[i], builder->getInt64(tile_size),
                                    "tile_origin." + std::to_string(i));
    }
    index;
  });
  const llvm_ir::IrArray::Index input_index =
      offset_dim(offset_dim(input_tile_origin, logical_x, /*dim=*/2),
                 logical_y, /*dim=*/1);
  std::vector<llvm::Value*> tile_dims(input_shape.dimensions().size());
  // Only the last row or column may not have the full size.
  for (int i = 1; i < 3; ++i) {
    tile_dims[i] = builder->CreateSelect(
        builder->CreateICmpEQ(input_tile_index[i],
                              builder->getInt64(input_dims_in_tiles[i] - 1)),
        builder->getInt64(input_shape.dimensions(i) -
                          (input_dims_in_tiles[i] - 1) * tile_size),
        builder->getInt64(tile_size), "tile_size");
  }

  // Load data from input memory to shared memory tile.
  emit_cp_tile(
      // tile[y, x] = input_array[index]
      [builder, tile, &input, logical_x](const llvm_ir::IrArray::Index& index,
                                         llvm::Value* y) {
        builder->CreateStore(
            input.EmitReadArrayElement(index, builder, "input_element"),
            builder->CreateGEP(tile, {builder->getInt64(0), y, logical_x}));
      },
      tile_dims[2], tile_dims[1], input_index, "input");

  // Wait for all threads to reach this point, lest we copy a value from tile
  // to output before another thread has copied it from input to tile. This is
  // `__syncthreads` in CUDA.
  llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {},
                               builder);
  const llvm_ir::IrArray::Index output_tile_index(
      Permute({0, 2, 1}, input_tile_index.multidim()));
  const llvm_ir::IrArray::Index output_tile_origin(
      Permute({0, 2, 1}, input_tile_origin.multidim()));
  const llvm_ir::IrArray::Index output_index =
      offset_dim(offset_dim(output_tile_origin, logical_x, /*dim=*/2),
                 logical_y, /*dim=*/1);

  // Store data from shared memory tile to output memory.
  emit_cp_tile(
      // output_array[index] = tile[x, y]
      [builder, tile, &output, logical_x](const llvm_ir::IrArray::Index& index,
                                          llvm::Value* y) {
        output.EmitWriteArrayElement(
            index,
            builder->CreateLoad(
                builder->CreateGEP(tile, {builder->getInt64(0), logical_x, y}),
                "output_element"),
            builder);
      },
      tile_dims[1], tile_dims[2], output_index, "output");

  return num_tiles;
}

}  // namespace

Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
  if (ImplementedAsHostToDeviceMemcpy(ir_emitter_context_->buffer_assignment(),
                                      *copy)) {
    thunk_sequence_->emplace_back(BuildHostToDeviceCopyThunk(copy));
    return Status::OK();
  }
  if (ImplementedAsDeviceToDeviceMemcpy(
          ir_emitter_context_->buffer_assignment(), *copy)) {
    thunk_sequence_->emplace_back(BuildDeviceToDeviceCopyThunk(copy));
    return Status::OK();
  }
  bool is_transpose_021;
  Shape reduced_input_shape, reduced_output_shape;
  std::tie(is_transpose_021, reduced_input_shape, reduced_output_shape) =
      IsTranspose021(copy->operand(0)->shape(), copy->shape());
  if (is_transpose_021 &&
      reduced_input_shape.dimensions(1) >= kMinDimensionToTransposeTiled &&
      reduced_input_shape.dimensions(2) >= kMinDimensionToTransposeTiled) {
    thunk_sequence_->emplace_back(BuildKernelThunk(copy));
    VLOG(3) << "Emitting tiled 0-2-1 transposition";
    constexpr int64 tile_size = 32;
    constexpr int64 num_rows = 8;
    int64 num_tiles = EmitTranspose021Tiled(
        GetIrArray(*copy->operand(0), *copy)
            .CastToShape(reduced_input_shape, &ir_builder_),
        GetIrArray(*copy, *copy)
            .CastToShape(reduced_output_shape, &ir_builder_),
        tile_size, num_rows, &ir_builder_);
    UpdateLaunchDimensions(LaunchDimensions(num_tiles, num_rows * tile_size),
                           LastThunk(), ir_emitter_context_->llvm_module());
    return Status::OK();
  }

  return IrEmitter::HandleCopy(copy);
}
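// For illustration of the launch dimensions chosen above: a 0-2-1 copy whose
// reduced input shape is [1, 128, 64] needs ceil(128/32) * ceil(64/32) =
// 4 * 2 = 8 tiles, i.e. 8 thread blocks of num_rows * tile_size = 256 threads
// each.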
Status IrEmitterUnnested::EmitReductionToScalar(
    HloInstruction* reduce, const Shape& input_shape,
    const llvm_ir::ElementGenerator& input_gen,
    const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) {
  // Number of elements processed by a single thread.
  constexpr int64 kTileSize = 16;
  int64 num_elems = ShapeUtil::ElementsIn(input_shape);

  // Round up the number of tiles to a multiple of the warp size. This is
  // necessary for correctness. We launch one thread per tile, and if the
  // number of threads isn't a multiple of the warp size, our shuffles will
  // read from inactive threads, producing undefined values.
  int64 num_tiles =
      RoundUpToNearest(CeilOfRatio(num_elems, kTileSize), kWarpSize);

  // Check whether every thread will process a full tile's worth of elements
  // without reading outside the bounds of the input. If this is true, we can
  // skip some bounds checks in the final algorithm.
  bool all_threads_in_bounds = num_tiles * kTileSize == num_elems;

  // __global__ void full_reduce_kernel() {
  //   x_in_tiles = threadIdx.x + blockIdx.x * blockDim.x;
  //   x = x_in_tiles * kTileSize;
  //
  //   partial_result = init_value;
  //   if (all_threads_in_bounds || x + kTileSize <= num_elems) {
  //     for (i = 0; i < kTileSize; ++i) {
  //       partial_result = Reducer(partial_result, input[x + i]);
  //     }
  //   } else {
  //     for (i = 0; i < kTileSize; ++i) {
  //       if (x + i < num_elems) {
  //         partial_result = Reducer(partial_result, input[x + i]);
  //       }
  //     }
  //   }
  //   for (i = warpSize / 2; i > 0; i /= 2) {
  //     partial_result = Reducer(partial_result,
  //                              __shfl_down(partial_result, i));
  //   }
  //   if (lane_id == 0) {
  //     AtomicReducer(&output[0], partial_result);
  //   }
  // }
  //
  // // Choose num_blocks and threads_per_block such that:
  // //
  // //   num_blocks * threads_per_block =
  // //       RoundUpToNextMultipleOf(Ceil(num_elems / kTileSize), warpSize),
  // //
  // // and threads_per_block is a multiple of warpSize.
  // reduce_kernel<<<num_blocks, threads_per_block>>>();
  //
  auto loop_body_emitter =
      [=](const llvm_ir::IrArray::Index& tile_index) -> Status {
    llvm::Type* element_ir_type =
        llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
    llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
        element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result");
    {
      TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value,
                          init_value_gen(llvm_ir::IrArray::Index({})));
      ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
    }

    llvm::Value* x_in_tiles = tile_index[0];

    // Emit an inner for-loop that reduces the elements in the tile.
    auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
      std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
          llvm_ir::ForLoop::EmitForLoop("element_id_in_tile",
                                        ir_builder_.getInt64(0),
                                        ir_builder_.getInt64(kTileSize),
                                        ir_builder_.getInt64(1), &ir_builder_);

      // Emit the body of the partial reduction loop.
      llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
                                     &ir_builder_);
      llvm::Value* x = ir_builder_.CreateNSWAdd(
          ir_builder_.CreateNSWMul(x_in_tiles,
                                   ir_builder_.getInt64(kTileSize)),
          tile_element_loop->GetIndVarValue());
      // Unless we know the tile is entirely in bounds, we have to emit an
      // x-in-bounds check before reading from the input.
      if (!tile_in_bounds) {
        llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
            ir_builder_.CreateICmpULT(x, ir_builder_.getInt64(num_elems)),
            "x_in_bounds", &ir_builder_);

        // Emit code that reads the input element and accumulates it to
        // the partial reduction result.
        llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
      }
      llvm_ir::IrArray::Index input_index(
          /*linear=*/x, input_shape, &ir_builder_);
      llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
      TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, input_gen(input_index));
      ir_builder_.CreateStore(input_ir_value, input_address);
      return (EmitCallToNestedComputation(
          *reducer, {partial_reduction_result_address, input_address},
          partial_reduction_result_address));
    };

    // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's
    // immediately beyond the tile.
    llvm::Value* x_end = ir_builder_.CreateNSWAdd(
        ir_builder_.getInt64(kTileSize),
        ir_builder_.CreateNSWMul(x_in_tiles, ir_builder_.getInt64(kTileSize)));
    // The tile is entirely in bounds if all_threads_in_bounds or
    // x_end <= num_elems.
    llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
        ir_builder_.CreateICmpULE(x_end, ir_builder_.getInt64(num_elems)),
        ir_builder_.getInt1(all_threads_in_bounds));
    llvm_ir::LlvmIfData if_tile_in_bounds_data =
        llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_);
    llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block,
                                   &ir_builder_);
    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true));
    llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block,
                                   &ir_builder_);
    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false));

    // After the if-then-else statement on tile_in_bounds, emit calls to
    // shfl_down that accumulate the partial reduction results of all threads
    // from the warp.
    llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block,
                                   &ir_builder_);
    int bit_width = llvm_ir::GetSizeInBits(element_ir_type);
    // bitcast cannot be applied to aggregate types (even packed ones), so we
    // instead bitcast addresses of load/store to intN* of the same bit-width.
    llvm::Type* shuffle_ir_type = element_ir_type->isStructTy()
                                      ? ir_builder_.getIntNTy(bit_width)
                                      : element_ir_type;
    for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1;
         shuffle_distance /= 2) {
      llvm::Value* partial_reduction_result = ir_builder_.CreateLoad(
          ir_builder_.CreateBitCast(partial_reduction_result_address,
                                    shuffle_ir_type->getPointerTo()),
          "partial_reduction_result");
      llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca(
          element_ir_type, nullptr, "result_from_other_lane");
      ir_builder_.CreateStore(
          EmitShuffleDown(partial_reduction_result,
                          ir_builder_.getInt32(shuffle_distance),
                          &ir_builder_),
          ir_builder_.CreateBitCast(result_from_other_lane,
                                    shuffle_ir_type->getPointerTo()));
      TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
          *reducer, {partial_reduction_result_address, result_from_other_lane},
          partial_reduction_result_address));
    }
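    // For illustration: with kWarpSize = 32, the loop above runs five rounds
    // with shuffle distances 16, 8, 4, 2, 1; after the last round, lane 0
    // holds the reduction of all 32 lanes in its warp.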
    const HloInstruction* output =
        reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;

    // Emit an atomic operation that accumulates the partial reduction result
    // of lane 0 (which holds the partially accumulated result for its warp)
    // to the output element.
    llvm::Value* lane_id = ir_builder_.CreateURem(
        x_in_tiles, ir_builder_.getInt64(kWarpSize), "lane_id");
    llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
        ir_builder_.CreateICmpEQ(lane_id, ir_builder_.getInt64(0)),
        "lane_id_is_zero", &ir_builder_);
    llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block,
                                   &ir_builder_);
    llvm::Value* output_address =
        GetIrArray(*output, *output)
            .EmitArrayElementAddress(
                llvm_ir::IrArray::Index(/*linear=*/ir_builder_.getInt64(0),
                                        output->shape(), &ir_builder_),
                &ir_builder_, "output_element_address");
    return EmitAtomicOperationForNestedComputation(
        *reducer, output_address, partial_reduction_result_address);
  };

  // Emit a parallel loop that iterates through all input tiles, one per
  // thread.
  Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
      reduce->shape().element_type(), {num_tiles}, {0});
  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
      tiled_input_shape, ir_emitter_context_->device_description());
  CHECK(LastThunk()->kind() == Thunk::Kind::kSequential);
  UpdateLaunchDimensions(
      launch_dimensions,
      static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(),
      ir_emitter_context_->llvm_module());
  return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape,
                             launch_dimensions, &ir_builder_)
      .EmitLoop(IrName(reduce));
}
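// For example, in EmitReductionToScalar above: num_elems = 1000 and
// kTileSize = 16 give CeilOfRatio(1000, 16) = 63 tiles, which
// RoundUpToNearest rounds up to num_tiles = 64, a multiple of kWarpSize = 32.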
Status IrEmitterUnnested::EmitColumnReduction(
    int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape,
    const llvm_ir::ElementGenerator& input_gen,
    const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) {
  // Divide the input matrix into tiles of size Kx1. For example, when the
  // input matrix is 4x4 and K=2, the tiled matrix looks like
  //
  //   0123
  //   0123
  //   4567
  //   4567  // Numbers indicate tile IDs.
  //
  // Each tile is first partially reduced to a scalar by a thread, and then
  // the scalar is accumulated to the output vector using atomic operations.
  // We choose 16 as the tile size, which matches Eigen's ColumnReduceKernel.
  constexpr int64 kTileSize = 16;
  // If the height is not a multiple of the tile size, we pad the bottom of
  // the input matrix.
  const int64 height_in_tiles = CeilOfRatio(height, kTileSize);

  // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x;
  //      linear_index < height_in_tiles * width;
  //      linear_index += blockDim.x * gridDim.x) {
  //   y_in_tiles = linear_index / width;
  //   x = linear_index % width;
  //
  //   partial_result = init_value;
  //   if (height % kTileSize == 0 ||
  //       y_in_tiles * kTileSize + kTileSize <= height) {
  //     for (element_id_in_tile : range(kTileSize)) {
  //       y = y_in_tiles * kTileSize + element_id_in_tile;
  //       partial_result = Reducer(partial_result, input[y][x]);
  //     }
  //   } else {
  //     for (element_id_in_tile : range(kTileSize)) {
  //       y = y_in_tiles * kTileSize + element_id_in_tile;
  //       if (y < height) {
  //         partial_result = Reducer(partial_result, input[y][x]);
  //       }
  //     }
  //   }
  //   AtomicReducer(&output[x], partial_result);
  // }
  auto loop_body_emitter =
      [=](const llvm_ir::IrArray::Index& tile_index) -> Status {
    // Emit the loop body that reduces one tile.
    llvm::Type* element_ir_type =
        llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
    llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
        element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result");
    {
      TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value,
                          init_value_gen(llvm_ir::IrArray::Index({})));
      ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
    }

    // Emit an inner for-loop that partially reduces the elements in the given
    // tile.
    llvm::Value* y_in_tiles = tile_index[0];
    llvm::Value* x = tile_index[1];

    auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
      std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
          llvm_ir::ForLoop::EmitForLoop("element_id_in_tile",
                                        ir_builder_.getInt64(0),
                                        ir_builder_.getInt64(kTileSize),
                                        ir_builder_.getInt64(1), &ir_builder_);

      // Emit the body of the partial reduction loop.
      llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
                                     &ir_builder_);
      llvm::Value* y = ir_builder_.CreateNSWAdd(
          ir_builder_.CreateNSWMul(y_in_tiles,
                                   ir_builder_.getInt64(kTileSize)),
          tile_element_loop->GetIndVarValue());
      // Unless we know the tile is entirely in bounds, we have to emit a
      // y-in-bounds check before reading from the input.
      if (!tile_in_bounds) {
        llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
            ir_builder_.CreateICmpULT(y, ir_builder_.getInt64(height)),
            "y_in_bounds", &ir_builder_);

        // Emit code that reads the input element and accumulates it to
        // the partial reduction result.
        llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
      }
      llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
      {
        // {y,x} is an index to input_matrix_shape [height,width]. We need to
        // convert that to an index to input_shape (the shape of the operand
        // of "reduce"). This conversion is composed of a transposition from
        // input_shape to normalized_input_shape and a reshape from
        // normalized_input_shape to input_matrix_shape.
        const Shape normalized_input_shape =
            ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
                input_shape);
        auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape);
        const std::vector<int64> transpose_dimension_mapping(
            input_shape_min2maj.rbegin(), input_shape_min2maj.rend());

        const Shape input_matrix_shape =
            ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(),
                                                     {height, width});
        const llvm_ir::IrArray::Index input_matrix_index(
            {y, x}, input_matrix_shape, &ir_builder_);
        const llvm_ir::IrArray::Index input_index =
            input_matrix_index
                .SourceIndexOfReshape(input_matrix_shape,
                                      normalized_input_shape, &ir_builder_)
                .SourceIndexOfTranspose(normalized_input_shape, input_shape,
                                        transpose_dimension_mapping,
                                        &ir_builder_);
        TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value,
                            input_gen(input_index));
        ir_builder_.CreateStore(input_ir_value, input_address);
      }
      return (EmitCallToNestedComputation(
          *reducer, {partial_reduction_result_address, input_address},
          partial_reduction_result_address));
    };
    // y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's
    // immediately beyond the tile.
    llvm::Value* y_end = ir_builder_.CreateNSWAdd(
        ir_builder_.getInt64(kTileSize),
        ir_builder_.CreateNSWMul(y_in_tiles, ir_builder_.getInt64(kTileSize)));
    // The tile is entirely in bounds if "height" is a multiple of kTileSize
    // or y_end <= height.
    llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
        ir_builder_.CreateICmpULE(y_end, ir_builder_.getInt64(height)),
        ir_builder_.getInt1(height % kTileSize == 0));
    llvm_ir::LlvmIfData if_tile_in_bounds_data =
        llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_);
    llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block,
                                   &ir_builder_);
    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true));
    llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block,
                                   &ir_builder_);
    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false));

    // After the if-then-else statement on tile_in_bounds, emit atomic
    // operations to accumulate the partial reduction result to the output
    // element.
    llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block,
                                   &ir_builder_);
    const HloInstruction* output =
        reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
    llvm::Value* output_address =
        GetIrArray(*output, *output)
            .EmitArrayElementAddress(
                llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_),
                &ir_builder_, "output_element_address");
    return EmitAtomicOperationForNestedComputation(
        *reducer, output_address, partial_reduction_result_address);
  };

  // Emit a parallel loop that iterates through all input tiles.
  Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
      reduce->shape().element_type(), {height_in_tiles, width}, {1, 0});
  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
      tiled_input_shape, ir_emitter_context_->device_description());
  CHECK(LastThunk()->kind() == Thunk::Kind::kSequential);
  UpdateLaunchDimensions(
      launch_dimensions,
      static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(),
      ir_emitter_context_->llvm_module());
  return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape,
                             launch_dimensions, &ir_builder_)
      .EmitLoop(IrName(reduce));
}
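// For example, in EmitColumnReduction above: a 4x4 input with a tile size of
// 2 (as in the picture in its comment) gives height_in_tiles = 2, i.e. a 2x4
// grid of 2x1 tiles; each of the 8 threads reduces one tile and atomically
// accumulates its partial result into output[x].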
  //
  // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x;
  //      linear_index < depth * height * width_in_tiles;
  //      linear_index += blockDim.x * gridDim.x) {
  //   int x_in_tiles = linear_index % width_in_tiles;
  //   int y = linear_index / width_in_tiles % height;
  //   int z = linear_index / (height * width_in_tiles);
  //   float partial_result = 0;
  //   for (element_id_in_tile : range(kTileSize)) {
  //     int x = x_in_tiles * kTileSize + element_id_in_tile;
  //     if (x < width)
  //       partial_result = reducer(partial_result, input[z][y][x]);
  //   }
  //   AtomicReducer(&output[y], partial_result);
  // }
  //
  // Three optimizations are performed.
  //
  // 1. To coalesce global memory accesses, dilate the tile with a factor of 32
  // (i.e. the warp size). For example, suppose the width is 8x32=256. Instead
  // of making each tile consecutive, we make tile 0 column [0,32,64,...,224],
  // tile 1 column [1,33,65,...,225], and so on. This ensures that threads in a
  // warp access consecutive memory in one iteration (i.e. coalesced). In the
  // above example, the warp that contains threads 0-31 accesses columns 0-31
  // in the first iteration, columns 32-63 in the second iteration, and so on.
  //
  // 2. Partially accumulate partial reduced results computed by threads in the
  // same warp using shfl_down. Using shfl_down is faster than directly using
  // atomic operations because shfl_down transfers the data between threads in
  // registers, and threads in the same warp run in lock step (thus no extra
  // synchronization is needed). See
  // https://devblogs.nvidia.com/parallelforall/faster-parallel-reductions-kepler/
  // for details. The downside is that, to produce correct results when using
  // shfl_down, we need to guarantee that threads in the same warp work on
  // input elements with the same y, so the number of tiles in each row must be
  // a multiple of 32.
  //
  // 3. Specialize the case in which the entire tile is in bounds. When that is
  // true, we don't need to emit "if (x < width)" inside the loop on
  // element_id_in_tile, which makes the code more friendly to optimizations
  // such as LICM.
  //
  // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x;
  //      linear_index < depth * height * width_in_tiles;
  //      linear_index += blockDim.x * gridDim.x) {
  //   int x_in_tiles = linear_index % width_in_tiles;
  //   int y = linear_index / width_in_tiles % height;
  //   int z = linear_index / (height * width_in_tiles);
  //   int warp_id = x_in_tiles / warpSize;
  //   int lane_id = x_in_tiles % warpSize;
  //   float partial_result = 0;
  //   int x = warp_id * kTileSize * warpSize + lane_id;
  //   if (width % (kTileSize * warpSize) == 0 ||
  //       x + (kTileSize - 1) * warpSize < width) {
  //     // The entire tile is in bounds.
  //     for (int element_id_in_tile = 0; element_id_in_tile < kTileSize;
  //          ++element_id_in_tile, x += warpSize) {
  //       partial_result = Reducer(partial_result, input[z][y][x]);
  //     }
  //   } else {
  //     // The tile is partially in bounds.
  //     for (int element_id_in_tile = 0; element_id_in_tile < kTileSize;
  //          ++element_id_in_tile, x += warpSize) {
  //       if (x < width)
  //         partial_result = Reducer(partial_result, input[z][y][x]);
  //     }
  //   }
  //   for (shuffle_distance = 16; shuffle_distance > 0; shuffle_distance /= 2)
  //     partial_result = Reducer(
  //         partial_result,
  //         __shfl_down_sync(CUDA_WARP_ALL, partial_result, shuffle_distance));
  //   if (lane_id == 0)
  //     AtomicReducer(&output[y], partial_result);
  // }
  //
  // Choose 8 as the tile size, which matches Eigen's RowReduceKernel.
  constexpr int64 kTileSize = 8;
  // Round the width in tiles up to the nearest multiple of kWarpSize, so that
  // the use of shfl_down is valid.
  const int64 width_in_tiles =
      RoundUpToNearest(CeilOfRatio(width, kTileSize), kWarpSize);

  auto loop_body_emitter =
      [=](const llvm_ir::IrArray::Index& tile_index) -> Status {
    // Emit the loop body that reduces one tile.
    llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(
        input_shape.element_type(), ir_emitter_context_->llvm_module());
    llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
        element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result");
    {
      TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value,
                          init_value_gen(llvm_ir::IrArray::Index({})));
      ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
    }

    // Emit an inner for-loop that partially reduces the elements in the given
    // tile.
    llvm::Value* z = tile_index[0];
    llvm::Value* y = tile_index[1];
    llvm::Value* x_tile = tile_index[2];
    llvm::Value* warp_id = ir_builder_.CreateUDiv(
        x_tile, ir_builder_.getInt64(kWarpSize), "warp_id");
    llvm::Value* lane_id = ir_builder_.CreateURem(
        x_tile, ir_builder_.getInt64(kWarpSize), "lane_id");

    // The x-location of the last element in this tile.
    //   last_x = lane_id + warpSize * (kTileSize - 1 + warp_id * kTileSize);
    llvm::Value* last_x = ir_builder_.CreateNSWAdd(
        lane_id,
        ir_builder_.CreateNSWMul(
            ir_builder_.getInt64(kWarpSize),
            ir_builder_.CreateNSWAdd(
                ir_builder_.getInt64(kTileSize - 1),
                ir_builder_.CreateNSWMul(warp_id,
                                         ir_builder_.getInt64(kTileSize)))));

    auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
      std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
          llvm_ir::ForLoop::EmitForLoop("element_id_in_tile",
                                        ir_builder_.getInt64(0),
                                        ir_builder_.getInt64(kTileSize),
                                        ir_builder_.getInt64(1), &ir_builder_);

      // Emit the body of the partial reduction loop.
      llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
                                     &ir_builder_);
      //   x = lane_id + warpSize * (element_id_in_tile + warp_id * kTileSize);
      llvm::Value* x = ir_builder_.CreateNSWAdd(
          lane_id,
          ir_builder_.CreateNSWMul(
              ir_builder_.getInt64(kWarpSize),
              ir_builder_.CreateNSWAdd(
                  tile_element_loop->GetIndVarValue(),
                  ir_builder_.CreateNSWMul(
                      warp_id, ir_builder_.getInt64(kTileSize)))));
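
      // For illustration only (hypothetical values, not emitted code): with
      // warpSize = 32 and kTileSize = 8, lane 5 of warp 2 reads
      //   x = 5 + 32 * (0 + 2 * 8) = 517, then 549, 581, ...
      // on successive iterations, so the 32 lanes of a warp always touch 32
      // consecutive elements of the row.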

      // Unless we know the tile is entirely in bounds, we have to emit an
      // x-in-bounds check before reading from the input.
      if (!tile_in_bounds) {
        llvm_ir::LlvmIfData if_x_in_bounds_data = llvm_ir::EmitIfThenElse(
            ir_builder_.CreateICmpULT(x, ir_builder_.getInt64(width)),
            "x_in_bounds", &ir_builder_);

        // Point ir_builder_ to the then-block.
        llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block,
                                       &ir_builder_);
      }

      // Emit code that reads the input element and accumulates it to the
      // partial reduction result.
      llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
      {
        // {z,y,x} is an index to input_3d_tensor_shape [depth,height,width].
        // We need to convert that to an index to input_shape (the shape of the
        // operand of "reduce"). This conversion is composed of a transposition
        // from input_shape to normalized_input_shape and a reshape from
        // normalized_input_shape to input_3d_tensor_shape.
        const Shape normalized_input_shape =
            ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
                input_shape);
        auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape);
        const std::vector<int64> transpose_dimension_mapping(
            input_shape_min2maj.rbegin(), input_shape_min2maj.rend());
        const Shape input_3d_tensor_shape =
            ShapeUtil::MakeShapeWithDescendingLayout(
                input_shape.element_type(), {depth, height, width});
        const llvm_ir::IrArray::Index input_3d_tensor_index(
            {z, y, x}, input_3d_tensor_shape, &ir_builder_);
        const llvm_ir::IrArray::Index input_index =
            input_3d_tensor_index
                .SourceIndexOfReshape(input_3d_tensor_shape,
                                      normalized_input_shape, &ir_builder_)
                .SourceIndexOfTranspose(normalized_input_shape, input_shape,
                                        transpose_dimension_mapping,
                                        &ir_builder_);
        TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value,
                            input_gen(input_index));
        ir_builder_.CreateStore(input_ir_value, input_address);
      }
      return EmitCallToNestedComputation(
          *reducer, {partial_reduction_result_address, input_address},
          partial_reduction_result_address);
    };

    llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
        ir_builder_.getInt1(width % (kTileSize * kWarpSize) == 0),
        ir_builder_.CreateICmpULT(last_x, ir_builder_.getInt64(width)));
    llvm_ir::LlvmIfData if_tile_in_bounds_data =
        llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_);
    llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block,
                                   &ir_builder_);
    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true));
    llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block,
                                   &ir_builder_);
    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false));

    // After the if-then-else statement on tile_in_bounds, emit calls to
    // shfl_down that accumulate the partial reduction results of all threads
    // from the warp.
    llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block,
                                   &ir_builder_);
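
    // For illustration only (not emitted code): the loop below is a standard
    // warp-level tree reduction. If lane i starts with a hypothetical value
    // v[i], then after the shuffle with distance 16, each lane i < 16 holds
    // Reducer(v[i], v[i+16]); halving the distance down to 1 leaves the
    // reduction of all 32 lane values in lane 0.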

    int bit_width = llvm_ir::GetSizeInBits(element_ir_type);
    // bitcast cannot be applied to aggregate types (even packed ones), so we
    // instead bitcast addresses of load/store to intN* of the same bit-width.
    llvm::Type* shuffle_ir_type = element_ir_type->isStructTy()
                                      ? ir_builder_.getIntNTy(bit_width)
                                      : element_ir_type;
    for (int shuffle_distance = 16; shuffle_distance >= 1;
         shuffle_distance /= 2) {
      llvm::Value* partial_reduction_result = ir_builder_.CreateLoad(
          ir_builder_.CreateBitCast(partial_reduction_result_address,
                                    shuffle_ir_type->getPointerTo()),
          "partial_reduction_result");
      llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca(
          element_ir_type, nullptr, "result_from_other_lane");
      ir_builder_.CreateStore(
          EmitShuffleDown(partial_reduction_result,
                          ir_builder_.getInt32(shuffle_distance),
                          &ir_builder_),
          ir_builder_.CreateBitCast(result_from_other_lane,
                                    shuffle_ir_type->getPointerTo()));
      TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
          *reducer, {partial_reduction_result_address, result_from_other_lane},
          partial_reduction_result_address));
    }

    const HloInstruction* output =
        reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;

    // Emit an atomic operation that accumulates the partial reduction result
    // of lane 0 (which holds the partially accumulated result for its warp) to
    // the output element.
    llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
        ir_builder_.CreateICmpEQ(lane_id, ir_builder_.getInt64(0)),
        "lane_id_is_zero", &ir_builder_);
    llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block,
                                   &ir_builder_);
    llvm::Value* output_address =
        GetIrArray(*output, *output)
            .EmitArrayElementAddress(
                llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_),
                &ir_builder_, "output_element_address");
    return EmitAtomicOperationForNestedComputation(
        *reducer, output_address, partial_reduction_result_address);
  };

  // Emit a parallel loop that iterates through all input tiles.
  Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
      reduce->shape().element_type(), {depth, height, width_in_tiles},
      {2, 1, 0});
  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
      tiled_input_shape, ir_emitter_context_->device_description());
  CHECK(LastThunk()->kind() == Thunk::Kind::kSequential);
  UpdateLaunchDimensions(
      launch_dimensions,
      static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(),
      ir_emitter_context_->llvm_module());
  return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape,
                             launch_dimensions, &ir_builder_)
      .EmitLoop(IrName(reduce));
}

// Figures out whether `reduce` is a row or column reduction, and which
// dimensions to reduce, and calls either `EmitRowReduction` or
// `EmitColumnReduction` as appropriate.
// Prerequisite: all the dimensions to keep are contiguous in the input layout
// and, if `reduce` is fused, the fused subgraph is pure elementwise.
Status IrEmitterUnnested::EmitReductionToVector(
    HloInstruction* reduce, const Shape& input_shape,
    const llvm_ir::ElementGenerator& input_gen,
    const llvm_ir::ElementGenerator& init_value_gen,
    tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
    HloComputation* reducer) {
  // This emission requires "reduce" to have an input layout. It is either set
  // by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for
  // a fused kReduce).
  CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion "
                                     "doesn't set the input layout of "
                                  << reduce->ToString();

  // Specialize multi-dimensional-array-to-vector reduction.
  std::vector<int64> input_dims_to_keep;
  for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape);
       ++input_dim) {
    if (std::find(dimensions_to_reduce.begin(), dimensions_to_reduce.end(),
                  input_dim) == dimensions_to_reduce.end()) {
      input_dims_to_keep.push_back(input_dim);
    }
  }

  // Sort the dimensions to keep from minor to major, to facilitate checking
  // whether another dimension is major or minor of them.
  std::sort(input_dims_to_keep.begin(), input_dims_to_keep.end(),
            [&input_shape](int64 dim_a, int64 dim_b) {
              return PositionInContainer(LayoutUtil::MinorToMajor(input_shape),
                                         dim_a) <
                     PositionInContainer(LayoutUtil::MinorToMajor(input_shape),
                                         dim_b);
            });
  // Now, if output rank is at least 1, `input_dims_to_keep.front()` is
  // minormost and `input_dims_to_keep.back()` is majormost.

  // If the dimensions to keep are minormost, emit a column reduction. As all
  // the dimensions to keep are contiguous, by the prerequisite of
  // `EmitReductionToVector`, we only need to check whether the minormost
  // dimension of the input is to be kept.
  if (input_dims_to_keep.empty()) {
    return EmitReductionToScalar(reduce, input_shape, input_gen,
                                 init_value_gen, reducer);
  } else if (input_dims_to_keep.front() ==
             LayoutUtil::Minor(input_shape.layout(), 0)) {
    // Column reduction. Treat the result of "input" as a matrix whose width
    // is the most minor dimension and height the product of other dimensions,
    // and treat "reduce" as a column reduction of the input matrix.
    const int64 width = ShapeUtil::ElementsIn(reduce->shape());
    // "width" can be zero, so don't do
    //   height = ShapeUtil::ElementsIn(input_shape) / width;
    int64 height = 1;
    for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape);
         ++input_dim) {
      if (!std::count(input_dims_to_keep.begin(), input_dims_to_keep.end(),
                      input_dim)) {
        height *= input_shape.dimensions(input_dim);
      }
    }
    return EmitColumnReduction(height, width, reduce, input_shape, input_gen,
                               init_value_gen, reducer);
  } else {
    // Reduce the row dimension of a matrix or reduce dimensions 0 and 2 in a
    // 3D tensor. The size of dimension 1 (the height) is the size of the
    // dimension to keep, the size of dimension 0 (the depth) is the product
    // of dimensions that are more major than the dimension to keep, and the
    // size of dimension 2 (the width) is the product of more minor
    // dimensions.
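    //
    // For illustration only (hypothetical shape, not emitted code): for an
    // input of shape [2,3,4,5] with a descending layout where only dimension
    // 1 is kept, depth = 2 (the more-major dimension 0), height = 3 (the kept
    // dimension), and width = 4 * 5 = 20 (the more-minor dimensions 2 and 3).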
    int64 depth = 1;
    int64 width = 1;
    for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape);
         ++input_dim) {
      if (PositionInContainer(LayoutUtil::MinorToMajor(input_shape),
                              input_dim) >
          PositionInContainer(LayoutUtil::MinorToMajor(input_shape),
                              input_dims_to_keep.back())) {
        depth *= input_shape.dimensions(input_dim);
      } else if (PositionInContainer(LayoutUtil::MinorToMajor(input_shape),
                                     input_dim) <
                 PositionInContainer(LayoutUtil::MinorToMajor(input_shape),
                                     input_dims_to_keep.front())) {
        width *= input_shape.dimensions(input_dim);
      }
    }
    const int64 height = ShapeUtil::ElementsIn(reduce->shape());
    return EmitRowReduction(depth, height, width, reduce, input_shape,
                            input_gen, init_value_gen, reducer);
  }
}

Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
  auto input = reduce->operand(0);
  auto init_value = reduce->operand(1);
  tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce(
      reduce->dimensions());
  HloComputation* reducer = reduce->to_apply();
  // HandleReduce specializes reduction from a multi-dimensional array to a 1D
  // array. The specialized version requires an initializer thunk that
  // initializes the output array to the initial value of the reduce.
  if (IsReductionToVector(*reduce) &&
      // The NVPTX backend can't do atomic cmpxchg any narrower than 32 bits.
      32 <= primitive_util::BitWidth(reduce->shape().element_type())) {
    std::vector<std::unique_ptr<Thunk>> thunks;
    thunks.emplace_back(BuildKernelThunk(reduce));
    TF_RETURN_IF_ERROR(EmitInitializer(
        reduce, static_cast<KernelThunk*>(thunks.back().get())));
    bindings_.UnbindAllLocalIrValues();
    thunks.emplace_back(BuildKernelThunk(reduce));
    thunk_sequence_->emplace_back(
        MakeUnique<SequentialThunk>(std::move(thunks), reduce));
    return EmitReductionToVector(
        reduce, input->shape(),
        [&](const llvm_ir::IrArray::Index& index) {
          return GetIrArray(*input, *reduce)
              .EmitReadArrayElement(index, &ir_builder_);
        },
        [&](const llvm_ir::IrArray::Index& index) {
          return GetIrArray(*init_value, *reduce)
              .EmitReadArrayElement(index, &ir_builder_);
        },
        dimensions_to_reduce, reducer);
  }

  thunk_sequence_->emplace_back(BuildKernelThunk(reduce));
  return IrEmitter::HandleReduce(reduce);
}

Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
  bool all_tuple_elements_have_buffer =
      c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) {
        return ir_emitter_context_->buffer_assignment().HasTopLevelAllocation(
            tuple_element);
      });
  // Tuples (especially tuples that are the final result of a computation) can
  // be so huge that if we were to emit a kernel that took each tuple element
  // as a parameter, we would exceed the max allowable number of parameters to
  // a GPU kernel, b/31336476. As an optimization, if all tuple elements have a
  // buffer, we collect their buffer addresses in a host array, and then copy
  // that array to the tuple's buffer.
  //
  // Some tuple elements (e.g. const or bitcast of const) might not have a
  // buffer -- their contents are stored in code. In that case, we fall back to
  // emitting kernels which have access to their buffer addresses in code.
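  //
  // For illustration only (hypothetical slices, not emitted code): a tuple
  // with three elements living in slices s0, s1 and s2 becomes a TupleThunk
  // that copies the host array {addr(s0), addr(s1), addr(s2)} into the
  // tuple's own buffer, so no kernel parameters are needed at all.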
  if (all_tuple_elements_have_buffer) {
    std::vector<BufferAllocation::Slice> tuple_element_buffers;
    for (const HloInstruction* tuple_element : tuple->operands()) {
      tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element));
    }
    thunk_sequence_->emplace_back(MakeUnique<TupleThunk>(
        tuple_element_buffers, GetAllocationSlice(*tuple), tuple));
    return Status::OK();
  }
  thunk_sequence_->emplace_back(BuildKernelThunk(tuple));
  return IrEmitter::HandleTuple(tuple);
}

Status IrEmitterUnnested::HandleGetTupleElement(HloInstruction*) {
  // GetTupleElement IR is emitted in the IR context of the user instruction,
  // and so we do not build a kernel for GetTupleElement instructions.
  return Status::OK();
}

Status IrEmitterUnnested::HandleSelectAndScatter(
    HloInstruction* select_and_scatter) {
  CHECK_EQ(select_and_scatter->operand_count(), 3);
  const auto* operand = select_and_scatter->operand(0);
  const auto* source = select_and_scatter->operand(1);
  const Window& window = select_and_scatter->window();
  PrimitiveType operand_element_type = operand->shape().element_type();
  const int64 rank = ShapeUtil::Rank(operand->shape());
  CHECK_EQ(rank, ShapeUtil::Rank(source->shape()));
  CHECK_EQ(rank, window.dimensions_size());

  {
    std::vector<std::unique_ptr<Thunk>> thunks;
    thunks.emplace_back(BuildKernelThunk(select_and_scatter));
    TF_RETURN_IF_ERROR(EmitInitializer(
        select_and_scatter, static_cast<KernelThunk*>(thunks.back().get())));
    bindings_.UnbindAllLocalIrValues();
    thunks.emplace_back(BuildKernelThunk(select_and_scatter));
    thunk_sequence_->emplace_back(
        MakeUnique<SequentialThunk>(std::move(thunks), select_and_scatter));
  }

  // TODO(b/31410564): Implement dilation rate for select-and-scatter.
  if (window_util::HasDilation(window)) {
    return Unimplemented(
        "Dilation for SelectAndScatter not implemented on GPU.");
  }

  // kSelectAndScatter is implemented as two kernel launches: the first launch
  // initializes the output array to the given initial value, and the second
  // accumulates the "source" matrix to the selected elements in the output
  // array. The first launch is already implemented by the initializer thunk
  // generated earlier, so this function only needs to take care of the
  // select-and-scatter part.
  //
  // Pseudo code for select-and-scatter:
  //
  // for (coordinates S in the source):  # This loop is parallel.
  //   initialized_flag = false
  //   for (coordinates W in the window):
  //     I = S * stride + W - pad_low
  //     if I within bounds of operand:
  //       if !(initialized_flag and select(selected_value, operand(I))):
  //         selected_value = operand(I)
  //         selected_index = I
  //         initialized_flag = true
  //   output(selected_index) = scatter(output(selected_index), source(S))
  auto loop_body_emitter =
      [=](const llvm_ir::IrArray::Index& source_index) -> Status {
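    // For illustration only (hypothetical window, not emitted code): with
    // rank = 2, a 3x3 window, stride 2 and padding_low 1 in both dimensions,
    // source index S = (0,0) visits operand indices I in {-1,0,1} x {-1,0,1};
    // the unsigned in-bounds comparison below discards the negative ones.
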
    // Allocate space to keep the currently selected value, its index, and a
    // boolean flag indicating whether the value is initialized. The
    // initialized_flag is set to false.
    llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
        llvm_ir::PrimitiveTypeToIrType(operand_element_type,
                                       ir_emitter_context_->llvm_module()),
        "selected_value_address", &ir_builder_);
    llvm::Value* selected_index_address =
        llvm_ir::EmitAllocaAtFunctionEntryWithCount(
            ir_builder_.getInt64Ty(), ir_builder_.getInt32(rank),
            "selected_index_address", &ir_builder_);
    llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
        ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_);
    ir_builder_.CreateStore(ir_builder_.getInt1(false),
                            initialized_flag_address);

    // Create the inner loop to iterate over the window.
    llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"),
                                      &ir_builder_);
    std::vector<int64> window_size;
    for (const auto& dim : window.dimensions()) {
      window_size.push_back(dim.size());
      CHECK_GT(dim.size(), 0);
    }
    const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape(
        ShapeUtil::MakeShape(operand_element_type, window_size), "window");
    llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
                                   &ir_builder_);

    // Compute the operand index to visit and evaluate the condition whether
    // the operand index is within the bounds. The unsigned comparison includes
    // checking whether the operand index >= 0.
    llvm_ir::IrArray::Index operand_index(source_index.size());
    llvm::Value* in_bounds_condition = ir_builder_.getInt1(true);
    for (int64 i = 0; i < rank; ++i) {
      llvm::Value* strided_index = ir_builder_.CreateNSWMul(
          source_index[i],
          ir_builder_.getInt64(window.dimensions(i).stride()));
      operand_index[i] = ir_builder_.CreateNSWSub(
          ir_builder_.CreateNSWAdd(strided_index, window_index[i]),
          ir_builder_.getInt64(window.dimensions(i).padding_low()));
      llvm::Value* index_condition = ir_builder_.CreateICmpULT(
          operand_index[i],
          ir_builder_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
      in_bounds_condition =
          ir_builder_.CreateAnd(in_bounds_condition, index_condition);
    }
    CHECK(in_bounds_condition != nullptr);

    // Only need to do something if the operand index is within the bounds.
    // First check if the initialized_flag is set.
    llvm_ir::LlvmIfData if_in_bounds =
        llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &ir_builder_);
    llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &ir_builder_);
    llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
        ir_builder_.CreateLoad(initialized_flag_address), "initialized",
        &ir_builder_);

    // If the initialized_flag is false, initialize the selected value and
    // index with the currently visiting operand.
    llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_);
    const auto save_operand_index = [&](
        const llvm_ir::IrArray::Index& operand_index) {
      for (int64 i = 0; i < rank; ++i) {
        llvm::Value* selected_index_address_slot =
            ir_builder_.CreateInBoundsGEP(selected_index_address,
                                          {ir_builder_.getInt32(i)});
        ir_builder_.CreateStore(operand_index[i], selected_index_address_slot);
      }
    };
    llvm_ir::IrArray operand_array = GetIrArray(*operand, *select_and_scatter);
    llvm::Value* operand_data =
        operand_array.EmitReadArrayElement(operand_index, &ir_builder_);
    ir_builder_.CreateStore(operand_data, selected_value_address);
    save_operand_index(operand_index);
    ir_builder_.CreateStore(ir_builder_.getInt1(true),
                            initialized_flag_address);

    // If the initialized_flag is true, call the `select` function to
    // potentially update the selected value and index with the currently
    // visiting operand.
    llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &ir_builder_);
    const Shape output_shape = ShapeUtil::MakeShape(PRED, {});
    llvm::Value* operand_address =
        operand_array.EmitArrayElementAddress(operand_index, &ir_builder_);
    llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
        llvm_ir::PrimitiveTypeToIrType(PRED,
                                       ir_emitter_context_->llvm_module()),
        "select_return_buffer", &ir_builder_);
    TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
        *select_and_scatter->select(),
        {selected_value_address, operand_address}, select_return_buffer));
    llvm::Value* result = ir_builder_.CreateLoad(select_return_buffer);

    // If the 'select' function returns false, update the selected value and
    // the index to the currently visiting operand.
    llvm::Value* cond = ir_builder_.CreateICmpNE(
        result,
        llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(
                                   PRED, ir_emitter_context_->llvm_module()),
                               0),
        "boolean_predicate");
    llvm_ir::LlvmIfData if_select_lhs =
        llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_);
    llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &ir_builder_);
    ir_builder_.CreateStore(ir_builder_.CreateLoad(operand_address),
                            selected_value_address);
    save_operand_index(operand_index);
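
    // For illustration only (hypothetical computations, not emitted code): if
    // `select` computes GE and `scatter` computes Add, this scan keeps the
    // position of the maximum element of each window, and the atomic scatter
    // below adds the corresponding source element to that position -- the way
    // a max-pooling gradient is typically expressed.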

    // After iterating over the window elements, scatter the source element to
    // the selected index of the output. The value we store at the output
    // location is computed by calling the `scatter` function with the source
    // value and the current output value.
    llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
                                   &ir_builder_);
    llvm_ir::IrArray::Index selected_index;
    for (int64 i = 0; i < rank; ++i) {
      llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP(
          selected_index_address, {ir_builder_.getInt32(i)});
      selected_index.push_back(
          ir_builder_.CreateLoad(selected_index_address_slot));
    }
    llvm::Value* source_value_address =
        GetIrArray(*source, *select_and_scatter)
            .EmitArrayElementAddress(source_index, &ir_builder_);
    llvm::Value* output_value_address =
        GetIrArray(*select_and_scatter, *select_and_scatter)
            .EmitArrayElementAddress(selected_index, &ir_builder_);
    return EmitAtomicOperationForNestedComputation(
        *select_and_scatter->scatter(), output_value_address,
        source_value_address);
  };

  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
      source->shape(), ir_emitter_context_->device_description());
  UpdateLaunchDimensions(
      launch_dimensions,
      // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk
      // consisting of two thunks, an initializer KernelThunk that initializes
      // the output and another KernelThunk that accumulates the scattered
      // elements.
      static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(),
      ir_emitter_context_->llvm_module());
  return ParallelLoopEmitter(loop_body_emitter, source->shape(),
                             launch_dimensions, &ir_builder_)
      .EmitLoop(IrName(select_and_scatter));
}

Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) {
  HloComputation* condition = xla_while->while_condition();
  TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) &&
               condition->root_instruction()->shape().element_type() == PRED)
      << "While condition computation must return bool";
  // Build a ForThunk for conformant while loops; otherwise build a WhileThunk.
  auto result = CanTransformWhileToFor(xla_while);
  if (result.ok()) {
    auto tuple = result.ConsumeValueOrDie();
    // loop_trip_count = (limit - start + increment - 1) / increment
    // e.g. with start = 0, limit = 10 and increment = 3 this rounds up to
    // ceil(10 / 3) = 4 trips.
    const int64 loop_trip_count =
        (std::get<1>(tuple) - std::get<0>(tuple) + std::get<2>(tuple) - 1) /
        std::get<2>(tuple);
    thunk_sequence_->emplace_back(BuildForThunk(xla_while, loop_trip_count));
    VLOG(3) << "Built ForThunk for while: " << xla_while->name();
  } else {
    thunk_sequence_->emplace_back(BuildWhileThunk(xla_while));
    VLOG(3) << "Built WhileThunk for while: " << xla_while->name()
            << " while-to-for transform status: " << result.status();
  }
  return Status::OK();
}

Status IrEmitterUnnested::HandleRng(HloInstruction* random) {
  thunk_sequence_->push_back(BuildKernelThunk(random));
  return IrEmitter::HandleRng(random);
}

Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
  thunk_sequence_->push_back(BuildKernelThunk(select));
  return IrEmitter::HandleSelect(select);
}

Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) {
  thunk_sequence_->emplace_back(BuildInfeedThunk(infeed));
  return Status::OK();
}

// Figures out how to access the buffers for all subshapes of hlo's operands
// and for hlo itself (i.e. all the buffers produced by HLO).
//
// Returns a map keyed on the pair {HloInstruction, ShapeIndex}.
// The value for this key is a pair {Slice, ShapeIndex}, where the slice tells
// you the root buffer to look in, and the ShapeIndex describes how to
// dereference starting at that buffer to get to the buffer in question.
//
// For example, if {hlo, {1}} is mapped to {slice, {3, 4}}, then the buffer for
// hlo at ShapeIndex {1} (i.e. the buffer for the second tuple element of hlo)
// is found at slice[3][4]. That is, slice is a void***, which we dereference
// twice -- first at index 3, and then at index 4 -- to get the address of our
// buffer.
//
// This function conservatively assumes that we'll touch all sub-buffers of
// every operand and of the output.
static std::map<std::pair<const HloInstruction*, ShapeIndex>,
                std::pair<BufferAllocation::Slice, ShapeIndex>>
GetHloBufferSlices(const HloInstruction* hlo,
                   const BufferAssignment& buffer_assn) {
  std::map<std::pair<const HloInstruction*, ShapeIndex>,
           std::pair<BufferAllocation::Slice, ShapeIndex>>
      slices;

  // Tries to find a slice plus an array of indices i1, ..., iN such that the
  // sub-buffer for instr at index can be found at slice[i1]...[iN].
  auto find_slice_for = [&](const HloInstruction* instr,
                            const ShapeIndex& index)
      -> optional<std::pair<BufferAllocation::Slice, ShapeIndex>> {
    // Simple, common case: Is the buffer for instr known at runtime? If so,
    // we're done.
    auto slice = GetKnownAtRuntimeSlice(instr, index, buffer_assn);
    if (slice.has_value()) {
      return {{*slice, ShapeIndex()}};
    }

    // If we don't know the buffer for instr at index, see if we know the
    // buffer for instr at index without its last element. If so, we can
    // dynamically find the buffer for instr by dereferencing a pointer in
    // that buffer. Continue looking this way until we run out of elements in
    // 'index'.
    ShapeIndex new_index = index;
    ShapeIndex gte_indices;
    while (!new_index.empty()) {
      gte_indices.push_front(new_index.back());
      new_index.pop_back();
      auto slice = GetKnownAtRuntimeSlice(instr, new_index, buffer_assn);
      if (slice.has_value()) {
        return {{*slice, gte_indices}};
      }
    }

    // If *that* didn't work, check whether instr is a GTE instruction. If it
    // is, see if we can get a buffer for its parent, and continue walking up
    // parents until we find a defined buffer or we hit something that's not a
    // GTE.
    const HloInstruction* parent = instr;
    while (parent->opcode() == HloOpcode::kGetTupleElement) {
      gte_indices.push_front(parent->tuple_index());
      parent = parent->operand(0);

      auto slice = GetKnownAtRuntimeSlice(parent, {}, buffer_assn);
      if (slice.has_value()) {
        return {{*slice, gte_indices}};
      }
    }

    return nullopt;
  };

  // Adds entries for all subshapes of instr to `slices`.
  auto add_slices_for = [&](const HloInstruction* instr) {
    // GPU constants don't have buffers; don't bother looking for one.
    if (instr->IsConstant()) {
      return;
    }

    ShapeUtil::ForEachSubshape(
        instr->shape(), [&](const Shape& /*shape*/, const ShapeIndex& index) {
          if (slices.count({instr, index})) {
            // HLOs can have duplicate operands; don't bother redoing work.
            return;
          }
          auto maybe_slice = find_slice_for(instr, index);
          if (maybe_slice.has_value()) {
            slices[{instr, index}] = *maybe_slice;
          } else {
            VLOG(1) << "Couldn't find buffer for " << instr->ToString()
                    << " at index " << index.ToString();
          }
        });
  };

  add_slices_for(hlo);
  for (const HloInstruction* operand : hlo->operands()) {
    // Conservatively assume we'll need the buffers for all subshapes of the
    // operand.
    add_slices_for(operand);
  }

  return slices;
}

Status IrEmitterUnnested::HandleGather(HloInstruction* gather) {
  // TODO(b/72710576): Gather is not implemented on GPUs.
  return Unimplemented("Gather is not implemented on GPUs.");
}

std::unique_ptr<Thunk> IrEmitterUnnested::BuildKernelThunk(
    const HloInstruction* inst) {
  const BufferAssignment& buffer_assn =
      ir_emitter_context_->buffer_assignment();

  std::map<std::pair<const HloInstruction*, ShapeIndex>,
           std::pair<BufferAllocation::Slice, ShapeIndex>>
      hlo_slices = GetHloBufferSlices(inst, buffer_assn);

  // Figure out which buffer allocations need to be passed as arguments to our
  // kernel. This is simply all of the allocations referenced in hlo_slices,
  // plus the XLA temp buffer (if we have it). We always include the temp
  // buffer because even if the kernel itself doesn't use it, a nested
  // subcomputation within the kernel (e.g. a kMap's computation) might.
  std::unordered_set<const BufferAllocation*> buffers_needed;
  for (const auto& kv : hlo_slices) {
    buffers_needed.insert(kv.second.first.allocation());
  }
  tensorflow::gtl::optional<const BufferAllocation*> temp_buffer;
  for (const BufferAllocation& alloc : buffer_assn.Allocations()) {
    if (alloc.IsPreallocatedTempBuffer()) {
      if (!temp_buffer.has_value()) {
        temp_buffer = &alloc;
      } else {
        LOG(FATAL) << "Multiple temp buffers found, but only one is allowed!";
      }
    }
  }
  if (temp_buffer.has_value()) {
    buffers_needed.insert(*temp_buffer);
  }

  // We'll pass a pointer to each of the elements of `buffers` to our kernel,
  // in this order.
  std::vector<const BufferAllocation*> buffers(buffers_needed.begin(),
                                               buffers_needed.end());
  std::sort(buffers.begin(), buffers.end(),
            [](const BufferAllocation* a, const BufferAllocation* b) {
              return a->index() < b->index();
            });

  llvm::Function* kernel = BuildKernelPrototype(*inst, buffers);

  // Build a map from a BufferAllocation to the corresponding argument in our
  // kernel.
  std::unordered_map<const BufferAllocation*, llvm::Value*> kernel_args;
  {
    auto arg_it = kernel->arg_begin();
    auto buffers_it = buffers.begin();
    for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) {
      kernel_args[*buffers_it] = arg_it;
    }
  }

  // For each buffer our kernel might want to touch, bind it to a value
  // derived from our kernel args.
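  //
  // For illustration only (hypothetical numbers, not emitted code): if a
  // slice lives at offset 256 of the allocation bound to kernel argument
  // %arg0 and has gte_index {1}, the loop below computes
  //   loc = load((i8**)(%arg0 + 256) + 1)
  // i.e. one extra pointer dereference per gte_index element.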
  for (const auto& kv : hlo_slices) {
    const HloInstruction* instr = kv.first.first;
    const ShapeIndex& index = kv.first.second;
    const BufferAllocation::Slice& slice = kv.second.first;
    const ShapeIndex& gte_index = kv.second.second;

    VLOG(3) << "Buffer for " << instr->ToString() << " at " << index.ToString()
            << " is found in slice " << slice.ToString() << " at GTE index "
            << gte_index.ToString();

    llvm::Value* loc =
        ir_builder_.CreateInBoundsGEP(kernel_args.at(slice.allocation()),
                                      {ir_builder_.getInt64(slice.offset())});

    // If gte_index is nonempty, we have to dereference `loc` to get to the
    // value we're ultimately interested in.
    llvm::Type* int8_double_pointer =
        llvm::PointerType::get(ir_builder_.getInt8PtrTy(), /*AddressSpace=*/0);
    for (int64 idx : gte_index) {
      loc = ir_builder_.CreateBitCast(loc, int8_double_pointer);
      loc = ir_builder_.CreateLoad(
          ir_builder_.CreateInBoundsGEP(loc, {ir_builder_.getInt64(idx)}));
    }

    bindings_.BindHloToIrValue(*instr, loc, index);
  }

  // Bind the temp buffer so that nested subcomputations can find it if they
  // need it.
  if (temp_buffer.has_value()) {
    bindings_.SetTempBufferBase(kernel_args.at(*temp_buffer));
  } else {
    bindings_.SetTempBufferBase(
        llvm::ConstantPointerNull::get(ir_builder_.getInt8PtrTy()));
  }

  return MakeUnique<KernelThunk>(buffers, llvm_ir::AsString(kernel->getName()),
                                 inst);
}

std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk(
    const HloInstruction* inst) {
  const HloInstruction* operand = inst->operand(0);
  CHECK_EQ(HloOpcode::kConstant, operand->opcode());
  return MakeUnique<HostToDeviceCopyThunk>(
      /*source_address=*/operand->literal().untyped_data(),
      /*destination_buffer=*/GetAllocationSlice(*inst),
      /*mem_size=*/
      llvm_ir::ByteSizeOf(operand->shape(),
                          ir_emitter_context_->llvm_module()->getDataLayout()),
      inst);
}

std::unique_ptr<Thunk> IrEmitterUnnested::BuildDeviceToDeviceCopyThunk(
    const HloInstruction* inst) {
  const HloInstruction* operand = inst->operand(0);
  return MakeUnique<DeviceToDeviceCopyThunk>(
      /*source_address=*/GetAllocationSlice(*operand),
      /*destination_buffer=*/GetAllocationSlice(*inst),
      /*mem_size=*/
      llvm_ir::ByteSizeOf(operand->shape(),
                          ir_emitter_context_->llvm_module()->getDataLayout()),
      inst);
}

std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk(
    const HloInstruction* inst) {
  CHECK_EQ(HloOpcode::kInfeed, inst->opcode());

  std::vector<BufferAllocation::Slice> tuple_element_buffers;
  for (int64 i = 0; i < inst->shape().tuple_shapes_size(); ++i) {
    BufferAllocation::Slice buffer = ir_emitter_context_->buffer_assignment()
                                         .GetUniqueSlice(inst, {i})
                                         .ConsumeValueOrDie();
    tuple_element_buffers.push_back(buffer);
  }

  return MakeUnique<InfeedThunk>(
      tuple_element_buffers,
      /*destination_buffer=*/GetAllocationSlice(*inst), inst);
}

std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
    const HloInstruction* inst) {
  if (inst->opcode() == HloOpcode::kDot) {
    const HloInstruction* lhs = inst->operand(0);
    const HloInstruction* rhs = inst->operand(1);
    return MakeUnique<GemmThunk>(
        GetAllocationSlice(*lhs),   // The buffer assigned to LHS.
        GetAllocationSlice(*rhs),   // The buffer assigned to RHS.
        GetAllocationSlice(*inst),  // The output buffer.
        lhs->shape(),               // The shape of LHS.
        rhs->shape(),               // The shape of RHS.
        inst->shape(),              // The shape of the output.
        false,                      // Do not transpose LHS.
        false,                      // Do not transpose RHS.
        inst);
  }

  if (inst->opcode() == HloOpcode::kFusion) {
    const HloInstruction* dot = inst->fused_expression_root();
    DCHECK(dot->opcode() == HloOpcode::kDot);
    const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0));
    const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1));
    DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter &&
           rhs_parameter->opcode() == HloOpcode::kParameter);
    const HloInstruction* lhs =
        inst->operand(lhs_parameter->parameter_number());
    const HloInstruction* rhs =
        inst->operand(rhs_parameter->parameter_number());

    return MakeUnique<GemmThunk>(
        GetAllocationSlice(*lhs),   // The buffer assigned to LHS.
        GetAllocationSlice(*rhs),   // The buffer assigned to RHS.
        GetAllocationSlice(*inst),  // The output buffer.
        lhs->shape(),               // The shape of LHS.
        rhs->shape(),               // The shape of RHS.
        inst->shape(),              // The shape of the output.
        dot->operand(0)->IsRank2Transpose(),  // Transpose LHS.
        dot->operand(1)->IsRank2Transpose(),  // Transpose RHS.
        inst);
  }

  LOG(FATAL) << "Cannot build a GemmThunk for " << inst->ToString();
}

std::unique_ptr<Thunk> IrEmitterUnnested::BuildFftThunk(
    const HloInstruction* inst) {
  const HloInstruction* operand = inst->operand(0);
  return MakeUnique<FftThunk>(inst->fft_type(), inst->fft_length(),
                              /*input_buffer=*/GetAllocationSlice(*operand),
                              /*output_buffer=*/GetAllocationSlice(*inst),
                              /*input_shape=*/operand->shape(),
                              /*output_shape=*/inst->shape(), inst);
}

Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo,
                                          KernelThunk* thunk) {
  bool fused = HloOpcode::kFusion == hlo->opcode();

  const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo;
  CHECK(inst->opcode() == HloOpcode::kSelectAndScatter ||
        inst->opcode() == HloOpcode::kReduce);
  const HloInstruction* init_value = nullptr;
  switch (inst->opcode()) {
    case HloOpcode::kSelectAndScatter:
      init_value = inst->operand(2);
      break;
    case HloOpcode::kReduce:
      init_value = inst->operand(1);
      break;
    default:
      LOG(FATAL) << "Opcode " << inst->opcode()
                 << " should not need an initializer.";
  }

  if (fused && init_value->opcode() == HloOpcode::kParameter) {
    init_value = hlo->operand(init_value->parameter_number());
  }

  return EmitTargetElementLoopInThunk(
      *hlo,
      [=](const llvm_ir::IrArray::Index& index) {
        return GetIrArray(*init_value, *hlo)
            .EmitReadArrayElement(index, &ir_builder_);
      },
      thunk);
}

namespace {

// Checks that the buffers corresponding to the given two HLOs share the same
// allocation.
Status CheckHloBuffersShareAllocation(
    const HloInstruction* a, const HloInstruction* b, const ShapeIndex& index,
    const BufferAssignment& buffer_assignment) {
  const BufferAllocation::Slice slice_a =
      buffer_assignment.GetUniqueSlice(a, index).ConsumeValueOrDie();
  const BufferAllocation::Slice slice_b =
      buffer_assignment.GetUniqueSlice(b, index).ConsumeValueOrDie();
  if (slice_a != slice_b) {
    return InternalError(
        "instruction %s %s does not share allocation with instruction %s %s",
        a->ToString().c_str(), slice_a.ToString().c_str(),
        b->ToString().c_str(), slice_b.ToString().c_str());
  }
  return Status::OK();
}

// Checks that all buffers used during while loop iteration share the same
// buffer allocation. This includes buffers for the while result, the while
// init operand, the condition parameter, the body parameter and the body
// result. Returns OK on success, error status otherwise.
Status CheckWhileBuffersShareAllocation(
    const HloInstruction* xla_while,
    const BufferAssignment& buffer_assignment) {
  return ShapeUtil::ForEachSubshapeWithStatus(
      xla_while->shape(),
      [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
        const HloInstruction* condition_parameter =
            xla_while->while_condition()->parameter_instruction(0);
        const HloComputation* body = xla_while->while_body();
        const HloInstruction* body_parameter = body->parameter_instruction(0);
        const HloInstruction* body_result = body->root_instruction();
        TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
            xla_while, xla_while->operand(0), index, buffer_assignment));
        TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
            xla_while, condition_parameter, index, buffer_assignment));
        TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
            xla_while, body_parameter, index, buffer_assignment));
        TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
            xla_while, body_result, index, buffer_assignment));
        return Status::OK();
      });
}

// Checks that the buffers used in a conditional instruction are shared with
// the operands and result as follows:
//   * The result buffer of the conditional should share the allocation with
//     the result buffers of the true and false computations.
//   * The buffer of operand 1 should share the allocation with the buffer of
//     the parameter 0 instruction of the true computation.
//   * The buffer of operand 2 should share the allocation with the buffer of
//     the parameter 0 instruction of the false computation.
Status CheckConditionalBuffersShareAllocation(
    const HloInstruction* conditional,
    const BufferAssignment& buffer_assignment) {
  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
      conditional->shape(),
      [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
        TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
            conditional, conditional->true_computation()->root_instruction(),
            index, buffer_assignment));
        TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
            conditional, conditional->false_computation()->root_instruction(),
            index, buffer_assignment));
        return Status::OK();
      }));
  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
      conditional->operand(1)->shape(),
      [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
        return CheckHloBuffersShareAllocation(
            conditional->operand(1),
            conditional->true_computation()->parameter_instruction(0), index,
            buffer_assignment);
      }));
  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
      conditional->operand(2)->shape(),
      [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
        return CheckHloBuffersShareAllocation(
            conditional->operand(2),
            conditional->false_computation()->parameter_instruction(0), index,
            buffer_assignment);
      }));
  return Status::OK();
}

}  // namespace

std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk(
    const HloInstruction* hlo) {
  // Check that all while-related buffers share an allocation.
  TF_CHECK_OK(CheckWhileBuffersShareAllocation(
      hlo, ir_emitter_context_->buffer_assignment()));

  // Generate thunk sequence for the while 'condition'.
  HloComputation* condition = hlo->while_condition();
  IrEmitterUnnested ir_emitter_condition(hlo_module_config_, condition,
                                         ir_emitter_context_);
  TF_CHECK_OK(condition->root_instruction()->Accept(&ir_emitter_condition));

  // Generate thunk sequence for the while 'body'.
  HloComputation* body = hlo->while_body();
  IrEmitterUnnested ir_emitter_body(hlo_module_config_, body,
                                    ir_emitter_context_);
  TF_CHECK_OK(body->root_instruction()->Accept(&ir_emitter_body));

  return MakeUnique<WhileThunk>(
      GetAllocationSlice(*condition->root_instruction()),  // cond result
      ir_emitter_condition.ConsumeThunkSequence(),
      ir_emitter_body.ConsumeThunkSequence(), hlo);
}

std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
    const HloInstruction* hlo, const int64 loop_limit) {
  // Check that all while-related buffers share an allocation.
  TF_CHECK_OK(CheckWhileBuffersShareAllocation(
      hlo, ir_emitter_context_->buffer_assignment()));

  // Generate thunk sequence for the while 'body' (will be used as the For
  // loop body).
  HloComputation* body = hlo->while_body();
  IrEmitterUnnested ir_emitter_body(hlo_module_config_, body,
                                    ir_emitter_context_);
  TF_CHECK_OK(body->root_instruction()->Accept(&ir_emitter_body));

  return MakeUnique<ForThunk>(loop_limit,
                              ir_emitter_body.ConsumeThunkSequence(), hlo);
}

std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
    const HloInstruction* hlo) {
  // Check that the buffers used in the conditional are shared with the
  // operands and result appropriately.
  TF_CHECK_OK(CheckConditionalBuffersShareAllocation(
      hlo, ir_emitter_context_->buffer_assignment()));

  HloComputation* true_computation = hlo->true_computation();
  IrEmitterUnnested ir_emitter_true(hlo_module_config_, true_computation,
                                    ir_emitter_context_);
  TF_CHECK_OK(true_computation->root_instruction()->Accept(&ir_emitter_true));

  HloComputation* false_computation = hlo->false_computation();
  IrEmitterUnnested ir_emitter_false(hlo_module_config_, false_computation,
                                     ir_emitter_context_);
  TF_CHECK_OK(
      false_computation->root_instruction()->Accept(&ir_emitter_false));

  return MakeUnique<ConditionalThunk>(
      GetAllocationSlice(*hlo->operand(0)),
      GetAllocationSlice(*hlo->operand(1)),
      GetAllocationSlice(*hlo->operand(2)),
      std::move(*ir_emitter_true.ConsumeThunkSequence()),
      std::move(*ir_emitter_false.ConsumeThunkSequence()), hlo);
}

Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
    const HloInstruction& hlo,
    const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk) {
  VLOG(3) << bindings_.ToString();

  const Shape& element_shape = hlo.IsMultiOutputFusion()
                                   ? ShapeUtil::GetSubshape(hlo.shape(), {0})
                                   : hlo.shape();
  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
      element_shape, ir_emitter_context_->device_description());
  UpdateLaunchDimensions(launch_dimensions, thunk,
                         ir_emitter_context_->llvm_module());
  if (!hlo.IsMultiOutputFusion()) {
    return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo),
                               launch_dimensions, &ir_builder_)
        .EmitLoop(IrName(&hlo));
  }

  // For multi-output fusion, we need to emit each operand and the root.
  std::vector<llvm_ir::IrArray> output_arrays;
  for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) {
    output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
  }
  TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays,
                                         launch_dimensions, &ir_builder_)
                         .EmitLoop(IrName(&hlo)));

  std::vector<llvm::Value*> tuple_operand_ptrs;
  for (int64 i = 0; i < output_arrays.size(); ++i) {
    tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
  }
  ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator());
  llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &ir_builder_,
                     module_);
  return Status::OK();
}

Status IrEmitterUnnested::EmitTargetElementLoop(
    const HloInstruction& hlo,
    const llvm_ir::ElementGenerator& element_generator) {
  CHECK(Thunk::Kind::kKernel == LastThunk()->kind());
  return EmitTargetElementLoopInThunk(hlo, element_generator,
                                      static_cast<KernelThunk*>(LastThunk()));
}

}  // namespace gpu
}  // namespace xla