      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <algorithm>
     17 #include <cstring>
     18 #include <iterator>
     19 #include <memory>
     20 #include <string>
     21 #include <vector>
     22 
     23 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
     24 
     25 #include "absl/algorithm/container.h"
     26 #include "absl/memory/memory.h"
     27 #include "absl/strings/str_cat.h"
     28 #include "absl/types/optional.h"
     29 #include "absl/types/span.h"
     30 #include "llvm/ADT/StringRef.h"
     31 #include "llvm/IR/BasicBlock.h"
     32 #include "llvm/IR/Function.h"
     33 #include "llvm/IR/IRBuilder.h"
     34 #include "llvm/IR/Instructions.h"
     35 #include "llvm/IR/LLVMContext.h"
     36 #include "llvm/IR/Module.h"
     37 #include "tensorflow/compiler/xla/layout_util.h"
     38 #include "tensorflow/compiler/xla/literal.h"
     39 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
     40 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
     41 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
     42 #include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"
     43 #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
     44 #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
     45 #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
     46 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
     47 #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h"
     48 #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
     49 #include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
     50 #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
     51 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
     52 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
     53 #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
     54 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
     55 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
     56 #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
     57 #include "tensorflow/compiler/xla/service/gpu/memset_thunk.h"
     58 #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
     59 #include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h"
     60 #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
     61 #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
     62 #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
     63 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
     64 #include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"
     65 #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h"
     66 #include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
     67 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
     68 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     69 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     70 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
     71 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     72 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
     73 #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
     74 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
     75 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
     76 #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"
     77 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
     78 #include "tensorflow/compiler/xla/service/name_uniquer.h"
     79 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
     80 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
     81 #include "tensorflow/compiler/xla/shape_util.h"
     82 #include "tensorflow/compiler/xla/status_macros.h"
     83 #include "tensorflow/compiler/xla/types.h"
     84 #include "tensorflow/compiler/xla/util.h"
     85 #include "tensorflow/compiler/xla/window_util.h"
     86 #include "tensorflow/compiler/xla/xla_data.pb.h"
     87 #include "tensorflow/core/lib/core/bits.h"
     88 #include "tensorflow/core/lib/core/status.h"
     89 #include "tensorflow/core/platform/logging.h"
     90 
     91 namespace xla {
     92 namespace gpu {
     93 
     94 using llvm_ir::KernelMappingScheme;
     95 using EmitElementFunction =
     96     std::function<void(const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
     97                        llvm::Value* x_loc, int64 x_iter_num)>;
     98 
     99 namespace {
    100 
    101 using absl::InlinedVector;
    102 using absl::nullopt;
    103 using absl::optional;
    104 using absl::StrCat;
    105 using llvm_ir::IrArray;
    106 using llvm_ir::IrName;
    107 
    108 namespace m = match;
    109 
     110 // If a dimension is smaller than this, untiled transposition may be more
     111 // efficient.
    112 const int64 kMinDimensionToTransposeTiled = 16;
    113 
     114 // Returns true if all paths from `hlo` to `root` contain only tuples. The
     115 // result of such an HloInstruction does not need to be materialized when the
     116 // computation can have a hybrid result.
    117 bool ReachRootViaOnlyTuples(const HloInstruction& hlo,
    118                             const HloInstruction& root) {
    119   if (hlo.opcode() != HloOpcode::kTuple) {
    120     return false;
    121   }
    122 
    123   if (&hlo == &root) {
    124     return true;
    125   }
    126 
    127   for (HloInstruction* user : hlo.users()) {
    128     if (!ReachRootViaOnlyTuples(*user, root)) {
    129       return false;
    130     }
    131   }
    132 
    133   return true;
    134 }
    135 
     136 // If `hlo` is a rank-2 transpose, returns its operand; otherwise returns `hlo` itself.
    137 const HloInstruction* StripTranspose(const HloInstruction& hlo) {
    138   if (hlo.IsRank2Transpose()) {
    139     return hlo.operand(0);
    140   }
    141   return &hlo;
    142 }
    143 
     144 // Updates the launch dimensions in "thunk" and annotates the launch dimensions
     145 // of the corresponding IR kernel in "llvm_module".
    146 // Precondition: "thunk" must be a KernelThunk.
    147 void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk,
    148                             llvm::Module* llvm_module) {
    149   CHECK(Thunk::Kind::kKernel == thunk->kind());
    150   KernelThunk* kernel_thunk = static_cast<KernelThunk*>(thunk);
    151   kernel_thunk->SetLaunchDimensions(launch_dims);
    152 
    153   // Add __launch_bounds__ to metadata. This limits registers per thread to
    154   // avoid out-of-resources launching errors.
    155   llvm::NamedMDNode* nvvm_annotations_node =
    156       llvm_module->getOrInsertNamedMetadata("nvvm.annotations");
    157   llvm::Function* ir_kernel =
    158       llvm_module->getFunction(kernel_thunk->kernel_name().c_str());
    159   llvm::LLVMContext& llvm_context = llvm_module->getContext();
    160   llvm::ConstantInt* threads_per_block_ir_value = llvm::ConstantInt::get(
    161       llvm::IntegerType::get(llvm_context, /*NumBits=*/32),
    162       launch_dims.threads_per_block());
    163   // Our launch bounds are exact, so we can specify them as reqntidx rather than
    164   // maxntidx.
    165   nvvm_annotations_node->addOperand(llvm::MDNode::get(
    166       llvm_context,
    167       {llvm::ConstantAsMetadata::get(ir_kernel),
    168        llvm::MDString::get(llvm_context, "reqntidx"),
    169        llvm::ConstantAsMetadata::get(threads_per_block_ir_value)}));
    170 }
    171 
    172 }  // namespace
    173 
    174 IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
    175                                      const HloComputation* hlo_computation,
    176                                      IrEmitterContext* ir_emitter_context)
    177     : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false),
    178       hlo_computation_(hlo_computation) {
    179   // Initialize thunk_sequence_ to an empty list of thunks.
    180   thunk_sequence_.reset(new ThunkSequence());
    181 }
    182 
    183 Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) {
    184   bindings_.UnbindAllLocalIrValues();
    185   return DfsHloVisitor::Postprocess(hlo);
    186 }
    187 
    188 llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
    189     const HloInstruction& inst,
    190     absl::Span<const BufferAllocation* const> args) {
     191   // Compute the kernel name. The instruction name may contain "-", which cannot
     192   // appear in a PTX function name, so sanitize the name before uniquifying it.
    193   string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName(
    194       llvm_ir::SanitizeFunctionName(inst.name()));
    195 
    196   // Create the kernel and add it to the module.
    197   llvm::Module* module = ir_emitter_context_->llvm_module();
    198   llvm::LLVMContext& context = module->getContext();
    199   llvm::FunctionType* kernel_type = llvm::FunctionType::get(
    200       /*Result=*/llvm::Type::getVoidTy(context),
    201       std::vector<llvm::Type*>(args.size(), b_.getInt8PtrTy()),
    202       /*isVarArg=*/false);
    203   llvm::Function* kernel =
    204       llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage,
    205                              kernel_name.c_str(), module);
    206 
    207   // Add dereferenceable and alignment information to each of the kernel's
    208   // parameters.
    209   auto arg_it = kernel->arg_begin();
    210   for (size_t arg_no = 0; arg_no < args.size(); ++arg_no) {
    211     const BufferAllocation* alloc = args[arg_no];
    212     llvm::Argument* fn_arg = &*arg_it;
    213     ++arg_it;
    214 
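             // Note: addDereferenceableAttr takes a 1-based attribute index (index 0
             // refers to the return value), whereas addParamAttr below takes the
             // 0-based parameter number.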
    215     kernel->addDereferenceableAttr(arg_no + 1, alloc->size());
    216 
    217     const int64 alignment = [&] {
    218       if (alloc->is_entry_computation_parameter()) {
    219         return kEntryParameterAlignBytes;
    220       } else if (alloc->is_constant()) {
    221         return kConstantBufferAlignBytes;
    222       } else {
    223         return kXlaAllocatedBufferAlignBytes;
    224       }
    225     }();
    226 
    227     kernel->addParamAttr(
    228         arg_no,
    229         llvm::Attribute::get(context, llvm::Attribute::Alignment, alignment));
    230 
    231     if (alloc->IsPreallocatedTempBuffer()) {
    232       fn_arg->setName("temp_buf");
    233     } else {
    234       fn_arg->setName(StrCat("alloc", alloc->index()));
    235     }
    236   }
    237 
    238   // TODO(b/65380986): Investigate if adding fast math flags for generated
    239   // kernels makes sense.
    240 
    241   // Add the declaration of this kernel to llvm.nvvm.annotations so that NVPTX
    242   // treats it as a CUDA kernel.
    243   llvm::NamedMDNode* nvvm_annotations_node =
    244       module->getOrInsertNamedMetadata("nvvm.annotations");
    245   nvvm_annotations_node->addOperand(llvm::MDNode::get(
    246       context, {llvm::ConstantAsMetadata::get(kernel),
    247                 llvm::MDString::get(context, "kernel"),
    248                 llvm::ConstantAsMetadata::get(b_.getInt32(1))}));
    249 
    250   // Update the insert point to the entry basic block.
    251   llvm::BasicBlock* entry_bb =
    252       llvm::BasicBlock::Create(context, /*Name=*/"entry", /*Parent=*/kernel);
    253 
    254   // Emit a "return void" at entry_bb's end, and set the insert point before
    255   // that return instruction.
    256   b_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb));
    257 
    258   return kernel;
    259 }
    260 
    261 namespace {
    262 // Computes the maximum valid unroll factor for a given instruction.
    263 int ComputeMaxUnrollFactor(const HloInstruction* hlo) {
    264   int max_unroll_factor = hlo->GetModule()
    265                               ->config()
    266                               .debug_options()
    267                               .xla_gpu_max_kernel_unroll_factor();
    268 
    269   // Find the largest possible power of two to unroll by.
    270   // TODO(kramerb): Make this smarter.
    271   const Shape& element_shape = hlo->IsMultiOutputFusion()
    272                                    ? ShapeUtil::GetSubshape(hlo->shape(), {0})
    273                                    : hlo->shape();
    274   int64 num_elements = ShapeUtil::ElementsIn(element_shape);
    275   for (int i = max_unroll_factor; i > 1; i /= 2) {
    276     if (num_elements % i == 0) {
    277       return i;
    278     }
    279   }
    280 
    281   // Cannot unroll.
    282   return 1;
    283 }
    284 
    285 // Returns the llvm type for the indices used in the kernel that contains the
    286 // hlo instruction. Such indices include the index for the parallel loop and
    287 // the indices for the tensors accessed by the kernel. The return type is i32
    288 // iff the following conditions are met:
    289 //  . The launch_size of the kernel is within the range of i32.
    290 //  . The sizes of all the tensors accessed within the kernel are within the
    291 //    range of i32.
    292 // Otherwise, the return type is i64.
    293 llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size,
    294                                   llvm::IRBuilder<>* b) {
     295   // Find the unnested hlo instruction for which the kernel is generated.
    296   const HloInstruction* unnested_hlo = hlo;
    297   const HloComputation* computation = hlo->parent();
    298   if (computation->IsFusionComputation()) {
    299     unnested_hlo = computation->FusionInstruction();
    300   }
    301 
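           // Returns true iff the element count of every array subshape of `s` is
           // within the range of i32.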
    302   auto shape_in_range = [&](const Shape& s) {
    303     bool in_range = true;
    304     ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape,
    305                                       const ShapeIndex& /*index*/) {
    306       if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) {
    307         in_range = false;
    308       }
    309     });
    310 
    311     return in_range;
    312   };
    313 
    314   llvm::Type* i64_ty = b->getInt64Ty();
    315   // Check launch dimension
    316   if (!IsInt32(launch_size)) {
    317     return i64_ty;
    318   }
    319 
    320   // Check the size of result tensors
    321   if (!shape_in_range(unnested_hlo->shape())) {
    322     return i64_ty;
    323   }
    324 
    325   auto hlo_shape_in_range = [&](const HloInstruction* operand) -> bool {
    326     return shape_in_range(operand->shape());
    327   };
    328 
    329   // Check the size of input tensors
    330   if (!absl::c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) {
    331     return i64_ty;
    332   }
    333 
    334   // Check the size of the internal result tensors
    335   if (unnested_hlo->opcode() == HloOpcode::kFusion) {
    336     if (!absl::c_all_of(
    337             unnested_hlo->fused_instructions_computation()->instructions(),
    338             hlo_shape_in_range)) {
    339       return i64_ty;
    340     }
    341   }
    342 
    343   return b->getInt32Ty();
    344 }
    345 
    346 }  // namespace
    347 
    348 Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
    349   return IrEmitter::DefaultAction(hlo);
    350 }
    351 
    352 Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
    353   if (ImplementedAsGemm(*dot)) {
    354     AddThunkToThunkSequence(BuildGemmThunk(dot));
    355     return Status::OK();
    356   }
    357   AddThunkToThunkSequence(
    358       BuildKernelThunk(dot, /*implements_whole_instruction=*/true));
    359   return IrEmitter::HandleDot(dot);
    360 }
    361 
    362 Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
    363   AddThunkToThunkSequence(BuildConditionalThunk(conditional));
    364   return Status::OK();
    365 }
    366 
    367 Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) {
    368   AddThunkToThunkSequence(
    369       BuildKernelThunk(convolution, /*implements_whole_instruction=*/true));
    370   return IrEmitter::HandleConvolution(convolution);
    371 }
    372 
    373 Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
    374   // A CustomCall on the GPU backend can either be a custom-call to a
    375   // user-supplied kernel, or a call into a library like cudnn.
    376 
    377   // Lower custom-calls to cudnn batchnorm ops to specialized thunks.  It's part
    378   // of the contract of these cudnn batchnorm calls that the epsilon and
    379   // feature_index operands be constants.
    380   if (custom_call->custom_call_target() ==
    381       kCudnnBatchNormForwardInferenceCallTarget) {
    382     const HloInstruction* epsilon = custom_call->operand(5);
    383     CHECK(epsilon->IsConstant());
    384     float epsilon_value = epsilon->literal().Get<float>({});
    385 
    386     const HloInstruction* feature_index = custom_call->operand(6);
    387     CHECK(feature_index->IsConstant());
    388     int64 feature_index_value = feature_index->literal().Get<int64>({});
    389 
    390     AddThunkToThunkSequence(
    391         absl::make_unique<CudnnBatchNormForwardInferenceThunk>(
    392             /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
    393             /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
    394             /*offset=*/GetAllocationSlice(*custom_call->operand(2)),
    395             /*mean=*/GetAllocationSlice(*custom_call->operand(3)),
    396             /*variance=*/GetAllocationSlice(*custom_call->operand(4)),
    397             /*epsilon=*/epsilon_value,
    398             /*feature_index=*/feature_index_value,
    399             /*output=*/GetAllocationSlice(*custom_call),
    400             /*hlo=*/custom_call));
    401     return Status::OK();
    402   }
    403 
    404   if (custom_call->custom_call_target() ==
    405       kCudnnBatchNormForwardTrainingCallTarget) {
    406     const HloInstruction* epsilon = custom_call->operand(3);
    407     CHECK(epsilon->IsConstant());
    408     float epsilon_value = epsilon->literal().Get<float>({});
    409 
    410     const HloInstruction* feature_index = custom_call->operand(4);
    411     CHECK(feature_index->IsConstant());
    412     int64 feature_index_value = feature_index->literal().Get<int64>({});
    413 
    414     // BatchNormTraining returns a tuple of three elements: data, calculated
    415     // mean, and calculated 1/sqrt(variance + epsilon).
    416     const auto& assn = ir_emitter_context_->buffer_assignment();
    417     auto output_data = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
    418     auto output_mean = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
    419     auto output_inv_stddev = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie();
    420     AddThunkToThunkSequence(
    421         absl::make_unique<CudnnBatchNormForwardTrainingThunk>(
    422             /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
    423             /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
    424             /*offset=*/GetAllocationSlice(*custom_call->operand(2)),
    425             /*epsilon=*/epsilon_value,
    426             /*feature_index=*/feature_index_value,
    427             /*output_data=*/output_data,
    428             /*output_mean=*/output_mean,
    429             /*output_inv_stddev=*/output_inv_stddev,
    430             /*output_tuple=*/GetAllocationSlice(*custom_call),
    431             /*hlo=*/custom_call));
    432     return Status::OK();
    433   }
    434 
    435   if (custom_call->custom_call_target() == kCudnnBatchNormBackwardCallTarget) {
    436     const HloInstruction* epsilon = custom_call->operand(5);
    437     CHECK(epsilon->IsConstant());
    438     float epsilon_value = epsilon->literal().Get<float>({});
    439 
    440     const HloInstruction* feature_index = custom_call->operand(6);
    441     CHECK(feature_index->IsConstant());
    442     int64 feature_index_value = feature_index->literal().Get<int64>({});
    443 
    444     // BatchNormGrad returns a tuple of three elements: grad_data, grad_scale,
    445     // grad_offset.
    446     const auto& assn = ir_emitter_context_->buffer_assignment();
    447     auto output_grad_data = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
    448     auto output_grad_scale = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
    449     auto output_grad_offset =
    450         assn.GetUniqueSlice(custom_call, {2}).ValueOrDie();
    451     AddThunkToThunkSequence(absl::make_unique<CudnnBatchNormBackwardThunk>(
    452         /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
    453         /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
    454         /*mean=*/GetAllocationSlice(*custom_call->operand(2)),
    455         /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)),
    456         /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)),
    457         /*epsilon=*/epsilon_value,
    458         /*feature_index=*/feature_index_value,
    459         /*output_grad_data=*/output_grad_data,
    460         /*output_grad_scale=*/output_grad_scale,
    461         /*output_grad_offset=*/output_grad_offset,
    462         /*output_tuple=*/GetAllocationSlice(*custom_call),
    463         /*hlo=*/custom_call));
    464     return Status::OK();
    465   }
    466 
    467   if (IsCustomCallToDnnConvolution(*custom_call)) {
    468     const auto& assn = ir_emitter_context_->buffer_assignment();
    469     std::vector<BufferAllocation::Slice> operand_slices;
    470     operand_slices.reserve(custom_call->operand_count());
    471     for (const auto* operand : custom_call->operands()) {
    472       operand_slices.push_back(GetAllocationSlice(*operand));
    473     }
    474     auto tuple_result_slice = GetAllocationSlice(*custom_call);
    475     auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
    476     auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
    477 
    478     AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
    479         Cast<HloCustomCallInstruction>(custom_call), std::move(operand_slices),
    480         conv_result_slice, scratch_slice, tuple_result_slice));
    481     return Status::OK();
    482   }
    483 
    484   if (custom_call->custom_call_target() == kCusolverCholeskyCallTarget) {
    485     TF_ASSIGN_OR_RETURN(CholeskyOptions options,
    486                         custom_call->backend_config<CholeskyOptions>());
    487 
    488     const Shape& shape = custom_call->operand(0)->shape();
    489     int ndim = shape.dimensions_size();
    490     CHECK_GE(ndim, 2);
    491     int64 n = shape.dimensions(ndim - 1);
    492 
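             // The batch size is the product of all but the last two (matrix)
             // dimensions.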
    493     const auto& dims = shape.dimensions();
    494     int64 batch_size = std::accumulate(dims.begin(), dims.end() - 2, int64{1},
    495                                        [](int64 a, int64 b) { return a * b; });
    496 
    497     auto operand_buffer = GetAllocationSlice(*custom_call->operand(0));
    498 
    499     const auto& assn = ir_emitter_context_->buffer_assignment();
    500     auto a_buffer = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
    501     auto workspace_buffer = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
    502     auto info_buffer = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie();
    503 
    504     std::vector<std::unique_ptr<Thunk>> thunks;
    505 
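             // The Cholesky decomposition is computed in place in `a_buffer`, so first
             // copy the operand into it when they are different buffers.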
    506     if (operand_buffer != a_buffer) {
    507       thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
    508           /*source_address=*/operand_buffer,
    509           /*destination_buffer=*/a_buffer,
    510           /*mem_size=*/ShapeUtil::ByteSizeOf(shape), custom_call));
    511     }
    512 
    513     thunks.push_back(absl::make_unique<CholeskyThunk>(
    514         options, a_buffer, workspace_buffer, info_buffer,
    515         custom_call->operand(0)->shape().element_type(), batch_size, n,
    516         custom_call));
    517 
    518     // Elide the sequential thunk if there's no copy.
    519     if (thunks.size() == 1) {
    520       AddThunkToThunkSequence(std::move(thunks[0]));
    521     } else {
    522       AddThunkToThunkSequence(
    523           absl::make_unique<SequentialThunk>(std::move(thunks), custom_call));
    524     }
    525 
    526     return Status::OK();
    527   }
    528 
    529   return IrEmitter::HandleCustomCall(custom_call);
    530 }
    531 
    532 Status IrEmitterUnnested::HandleFft(HloInstruction* fft) {
    533   TF_RET_CHECK(
    534       LayoutUtil::IsMonotonicWithDim0Major(fft->operand(0)->shape().layout()));
    535   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout()));
    536   AddThunkToThunkSequence(BuildFftThunk(fft));
    537   return Status::OK();
    538 }
    539 
    540 Status IrEmitterUnnested::HandleTriangularSolve(HloInstruction* hlo) {
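           // The two minor-most dimensions must hold the matrix in column-major
           // (Fortran) order; any batch dimensions must be major to them.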
    541   auto has_fortran_layout = [](const Layout& layout) {
    542     int n = layout.minor_to_major_size();
    543     return layout.minor_to_major(0) == n - 2 &&
    544            layout.minor_to_major(1) == n - 1;
    545   };
    546   TF_RET_CHECK(has_fortran_layout(hlo->operand(0)->shape().layout()));
    547   TF_RET_CHECK(has_fortran_layout(hlo->operand(1)->shape().layout()));
    548   TF_RET_CHECK(has_fortran_layout(hlo->shape().layout()));
    549 
    550   std::vector<std::unique_ptr<Thunk>> thunks;
    551 
    552   // Triangular solve is in-place on 'b', so copy 'b' to the output if they
    553   // aren't the same buffer.
    554   auto operand_buffer = GetAllocationSlice(*hlo->operand(1));
    555   auto destination_buffer = GetAllocationSlice(*hlo);
    556   if (operand_buffer != destination_buffer) {
    557     thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
    558         /*source_address=*/operand_buffer,
    559         /*destination_buffer=*/destination_buffer,
    560         /*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(1)->shape()), hlo));
    561   }
    562 
    563   thunks.push_back(BuildTriangularSolveThunk(hlo));
    564 
    565   // Elide the sequential thunk if there's no copy.
    566   if (thunks.size() == 1) {
    567     AddThunkToThunkSequence(std::move(thunks[0]));
    568   } else {
    569     AddThunkToThunkSequence(
    570         absl::make_unique<SequentialThunk>(std::move(thunks), hlo));
    571   }
    572   return Status::OK();
    573 }
    574 
    575 Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
    576   HloInstruction* root = fusion->fused_expression_root();
    577   if (HloInstruction::FusionKind::kInput == fusion->fusion_kind()) {
    578     switch (root->opcode()) {
    579       case HloOpcode::kScatter: {
    580         std::vector<std::unique_ptr<Thunk>> thunks;
     581         // The initialization from 'operand' uses different loop bounds, so emit
     582         // it in a separate kernel. Treat it like a loop fusion, writing to
    583         // the output buffer.
    584         {
    585           int unroll_factor = ComputeMaxUnrollFactor(fusion);
    586           thunks.push_back(BuildKernelThunk(
    587               fusion, /*implements_whole_instruction=*/false, unroll_factor));
    588 
    589           GpuElementalIrEmitter operand_elemental_emitter(
    590               hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
    591               GetNestedComputer());
    592           FusedIrEmitter operand_fused_emitter(
    593               GetGeneratorForOperandIrArrays(fusion),
    594               &operand_elemental_emitter);
    595           TF_RETURN_IF_ERROR(
    596               root->mutable_operand(0)->Accept(&operand_fused_emitter));
    597 
    598           TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk(
    599               *fusion, operand_fused_emitter.GetGenerator(root->operand(0)),
    600               static_cast<KernelThunk*>(thunks.back().get())));
    601         }
    602 
    603         // Now build the actual scatter, reading and writing to the freshly
    604         // filled output buffer.
    605         {
    606           thunks.push_back(
    607               BuildKernelThunk(fusion,
    608                                /*implements_whole_instruction=*/false));
    609           // Spin up a new fused emitter for the scatter kernel and emit it.
    610           GpuElementalIrEmitter scatter_elemental_emitter(
    611               hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
    612               GetNestedComputer());
    613           FusedIrEmitter scatter_fused_emitter(
    614               GetGeneratorForOperandIrArrays(fusion),
    615               &scatter_elemental_emitter);
    616           TF_RETURN_IF_ERROR(root->Accept(&scatter_fused_emitter));
    617           TF_RETURN_IF_ERROR(EmitScatter(
    618               thunks.back().get(), root,
    619               /*scatter_indices_gen=*/
    620               scatter_fused_emitter.GetGenerator(root->operand(1)),
    621               /*updates_gen=*/
    622               scatter_fused_emitter.GetGenerator(root->operand(2))));
    623         }
    624         AddThunkToThunkSequence(
    625             absl::make_unique<SequentialThunk>(std::move(thunks), fusion));
    626         return Status::OK();
    627       }
    628       case HloOpcode::kTuple:
    629       case HloOpcode::kReduce: {
    630         // HandleFusion specializes reduction from a multi-dimensional array to
     631         // a 1D array. The specialized version requires an initializer thunk that
    632         // initializes the output array to the initial value of the reduce.
    633         if (root->opcode() == HloOpcode::kReduce && root->shape().IsTuple()) {
    634           // TODO(b/118332391): Support variadic reduce.
    635           return Unimplemented("Variadic reduce is not supported on GPU");
    636         }
    637         return EmitReductionToVector(fusion);
    638       }
    639       default:
    640         LOG(FATAL) << "Bad opcode for input fusion: "
    641                    << fusion->fused_expression_root()->opcode();
    642     }
    643   } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(
    644                  fusion, ir_emitter_context_->buffer_assignment())) {
    645     // Fusion node with dynamic-update-slice as the root where the op's input
    646     // (i.e. array to update) shares the same slice as its output.  In this case
    647     // we have a special algorithm that modifies the output in place without
    648     // touching the un-updated elements.
    649 
    650     // Set up kernel thunk and fused ir emitter.
    651     std::unique_ptr<KernelThunk> fusion_thunk =
    652         BuildKernelThunk(fusion, /*implements_whole_instruction=*/true);
    653     GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
    654                                             ir_emitter_context_->llvm_module(),
    655                                             &b_, GetNestedComputer());
    656 
    657     // Shape of the dynamic-update-slice's "update" operand.
    658     Shape update_shape = root->operand(1)->shape();
    659 
    660     // Array to write into.  Because this is an in-place operation, this is the
    661     // same as operand 0's array.
    662     IrArray output_array = GetIrArray(*fusion, *fusion);
    663 
    664     LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
    665         update_shape, ir_emitter_context_->device_description());
    666     UpdateLaunchDimensions(launch_dimensions, fusion_thunk.get(),
    667                            ir_emitter_context_->llvm_module());
    668     AddThunkToThunkSequence(std::move(fusion_thunk));
    669 
    670     return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace(
    671         fusion, GetGeneratorForOperandIrArrays(fusion), output_array,
    672         &elemental_emitter, launch_dimensions, &b_);
    673   }
    674 
    675   if (ImplementedAsGemm(*fusion)) {
    676     AddThunkToThunkSequence(BuildGemmThunk(fusion));
    677     return Status::OK();
    678   }
    679 
    680   CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop);
    681 
    682   if (CheckAndEmitHloWithTile021(fusion)) {
    683     return Status::OK();
    684   }
    685 
    686   return IrEmitter::HandleFusion(fusion);
    687 }
    688 
    689 Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
    690   CHECK(ShapeUtil::Compatible(copy->operand(0)->shape(), copy->shape()));
    691   const BufferAssignment& buffer_assignment =
    692       ir_emitter_context_->buffer_assignment();
    693   if (LayoutUtil::Equal(copy->operand(0)->shape().layout(),
    694                         copy->shape().layout()) &&
    695       buffer_assignment.GetUniqueTopLevelSlice(copy->operand(0)).ok()) {
    696     AddThunkToThunkSequence(BuildDeviceToDeviceCopyThunk(copy));
    697     return Status::OK();
    698   }
    699   if (CheckAndEmitHloWithTile021(copy)) {
    700     return Status::OK();
    701   }
    702 
    703   return IrEmitter::HandleCopy(copy);
    704 }
    705 
    706 Status IrEmitterUnnested::EmitExtraOutputsForReduce(
    707     const HloInstruction* unnested_hlo, const IrArray::Index& index,
    708     absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
    709         extra_output_gens) {
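           // For each extra output, compute the element address at `index` and store
           // the value produced by the corresponding generator.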
    710   for (int i = 0; i != extra_output_gens.size(); ++i) {
    711     llvm::Value* extra_output_address =
    712         GetIrArray(*unnested_hlo, *unnested_hlo, extra_output_gens[i].second)
    713             .EmitArrayElementAddress(index, &b_,
    714                                      "extra_output_element_address");
    715     TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
    716                         extra_output_gens[i].first(index));
    717     Store(extra_output_ir_value, extra_output_address);
    718   }
    719   return Status::OK();
    720 }
    721 
    722 Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
    723   // TODO(b/118332391): Support multi-output reduce.
    724   if (!reduce->shape().IsArray()) {
    725     return Unimplemented("Multi-output reduce is not supported on GPU");
    726   }
    727   if (IsReductionToVector(*reduce)) {
    728     return EmitReductionToVector(reduce);
    729   }
    730 
    731   return IrEmitter::HandleReduce(reduce);
    732 }
    733 
    734 Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
    735   // For the root node of the entry computation we can elide writing the tuple
    736   // buffer. We can always figure out the contents of the tuples from buffer
    737   // assignment because we insert copies to ensure non-ambiguous output buffers.
    738   // GpuExecutable never reads the tuple buffer.
    739   if (tuple ==
    740       tuple->parent()->parent()->entry_computation()->root_instruction()) {
    741     return Status::OK();
    742   }
    743   bool all_tuple_elements_have_buffer =
    744       absl::c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) {
    745         return ir_emitter_context_->buffer_assignment()
    746             .GetUniqueTopLevelSlice(tuple_element)
    747             .ok();
    748       });
    749   // TODO(b/111689850): This logic isn't quite correct.
    750   //
    751   // Tuples (especially tuples that are the final result of a computation) can
    752   // be so huge that if we were to emit a kernel that took each tuple element as
    753   // a parameter, we would exceed the max allowable number of parameters to a
    754   // GPU kernel, b/31336476. As an optimization, if all tuple elements have a
    755   // buffer, we collect their buffer addresses in a host array, and then copy
    756   // that array to the tuple's buffer.
    757   //
    758   // Some tuple elements might not have an unambiguous buffer (like the result
    759   // of a select-tuple). In that case, we fall back to emitting kernels which
    760   // have access to their buffer addresses in code.
    761   if (all_tuple_elements_have_buffer) {
    762     std::vector<BufferAllocation::Slice> tuple_element_buffers;
    763     for (const HloInstruction* tuple_element : tuple->operands()) {
    764       tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element));
    765     }
    766     AddThunkToThunkSequence(absl::make_unique<TupleThunk>(
    767         tuple_element_buffers, GetAllocationSlice(*tuple), tuple));
    768     return Status::OK();
    769   }
    770   AddThunkToThunkSequence(
    771       BuildKernelThunk(tuple, /*implements_whole_instruction=*/true));
    772   return IrEmitter::HandleTuple(tuple);
    773 }
    774 
    775 Status IrEmitterUnnested::HandleGetTupleElement(HloInstruction*) {
    776   // GetTupleElement IR is emitted in the IR context of the user instruction,
    777   // and so we do not build a kernel for GetTupleElement instructions.
    778   return Status::OK();
    779 }
    780 
    781 Status IrEmitterUnnested::HandleSelectAndScatter(
    782     HloInstruction* select_and_scatter) {
    783   CHECK_EQ(select_and_scatter->operand_count(), 3);
    784   const auto* operand = select_and_scatter->operand(0);
    785   const auto* source = select_and_scatter->operand(1);
    786   const Window& window = select_and_scatter->window();
    787   PrimitiveType operand_element_type = operand->shape().element_type();
    788   const int64 rank = operand->shape().rank();
    789   CHECK_EQ(rank, source->shape().rank());
    790   CHECK_EQ(rank, window.dimensions_size());
    791 
    792   TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
    793                       BuildInitializerThunk(select_and_scatter));
    794   std::vector<std::unique_ptr<Thunk>> thunks;
    795   thunks.push_back(std::move(initializer_thunk));
    796   thunks.push_back(BuildKernelThunk(select_and_scatter,
    797                                     /*implements_whole_instruction=*/false));
    798   std::unique_ptr<SequentialThunk> select_and_scatter_thunk =
    799       absl::make_unique<SequentialThunk>(std::move(thunks), select_and_scatter);
    800 
    801   // TODO(b/31410564): Implement dilation rate for select-and-scatter.
    802   if (window_util::HasDilation(window)) {
    803     return Unimplemented(
    804         "Dilation for SelectAndScatter not implemented on GPU.");
    805   }
    806 
    807   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
    808       source->shape(), ir_emitter_context_->device_description());
    809   llvm::Type* index_type = GetIndexTypeForKernel(
    810       select_and_scatter, launch_dimensions.launch_bound(), &b_);
    811   auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
    812     return llvm::ConstantInt::get(index_type, c);
    813   };
    814 
    815   // kSelectAndScatter is implemented as two kernel launches: the first launch
    816   // initializes the output array to the given initial value,
     817   // and the second accumulates the "source" matrix into the
    818   // selected elements in the output array. The first launch is already
    819   // implemented by the initializer thunk generated earlier, so this function
    820   // only needs to take care of the select-and-scatter part.
    821   //
    822   // Pseudo code for select-and-scatter:
    823   //
    824   // for (coordinates S in the source):  # This loop is parallel.
    825   //   initialized_flag = false
    826   //   for (coordinates W in the window):
    827   //     I = S * stride + W - pad_low
    828   //     if I within bounds of operand:
    829   //       if !(initialized_flag and select(selected_value, operand(I))):
    830   //         selected_value = operand(I)
    831   //         selected_index = I
    832   //         initialized_flag = true
    833   //   output(selected_index) = scatter(output(selected_index), source(S))
    834   auto loop_body_emitter = [=](const IrArray::Index& source_index) -> Status {
     835     // Allocate space to keep the currently selected value, its index, and a
     836     // boolean flag indicating whether the value is initialized. The
     837     // initialized_flag starts out false.
    838     llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
    839         llvm_ir::PrimitiveTypeToIrType(operand_element_type,
    840                                        ir_emitter_context_->llvm_module()),
    841         "selected_value_address", &b_);
    842     llvm::Value* selected_index_address =
    843         llvm_ir::EmitAllocaAtFunctionEntryWithCount(
    844             index_type, index_typed_constant(rank), "selected_index_address",
    845             &b_);
    846     llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
    847         b_.getInt1Ty(), "initialized_flag_address", &b_);
    848     Store(b_.getInt1(false), initialized_flag_address);
    849 
    850     // Create the inner loop to iterate over the window.
    851     llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_,
    852                                       index_type);
    853     DimensionVector window_size;
    854     for (const auto& dim : window.dimensions()) {
    855       window_size.push_back(dim.size());
    856       CHECK_GT(dim.size(), 0);
    857     }
    858     const IrArray::Index window_index = window_loops.AddLoopsForShape(
    859         ShapeUtil::MakeShape(operand_element_type, window_size), "window");
    860     llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
    861                                    &b_);
    862 
     863     // Compute the operand index to visit and evaluate whether it is within
     864     // bounds. The unsigned comparison also covers the check that the operand
     865     // index is >= 0.
    866     std::vector<llvm::Value*> operand_multi_index(source_index.size());
    867     llvm::Value* in_bounds_condition = b_.getInt1(true);
    868     for (int64 i = 0; i < rank; ++i) {
    869       llvm::Value* strided_index = NSWMul(
    870           source_index[i], index_typed_constant(window.dimensions(i).stride()));
    871       operand_multi_index[i] =
    872           NSWSub(NSWAdd(strided_index, window_index[i]),
    873                  index_typed_constant(window.dimensions(i).padding_low()));
    874       llvm::Value* index_condition = ICmpULT(
    875           operand_multi_index[i],
    876           index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i)));
    877       in_bounds_condition = And(in_bounds_condition, index_condition);
    878     }
    879     CHECK(in_bounds_condition != nullptr);
    880 
    881     // Only need to do something if the operand index is within the bounds.
    882     // First check if the initialized_flag is set.
    883     llvm_ir::LlvmIfData if_in_bounds =
    884         llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
    885     llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
    886     llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
    887         Load(initialized_flag_address), "initialized", &b_);
    888 
     889     // If the initialized_flag is false, initialize the selected value and index
     890     // with the operand element currently being visited.
    891     llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &b_);
    892     const auto save_operand_index = [&](const IrArray::Index& operand_index) {
    893       for (int64 i = 0; i < rank; ++i) {
    894         llvm::Value* selected_index_address_slot =
    895             InBoundsGEP(selected_index_address, {b_.getInt32(i)});
    896         Store(operand_index[i], selected_index_address_slot);
    897       }
    898     };
    899     IrArray operand_array = GetIrArray(*operand, *select_and_scatter);
    900     IrArray::Index operand_index(operand_multi_index, operand->shape(),
    901                                  index_type);
    902     llvm::Value* operand_data =
    903         operand_array.EmitReadArrayElement(operand_index, &b_);
    904     Store(operand_data, selected_value_address);
    905     save_operand_index(operand_index);
    906     Store(b_.getInt1(true), initialized_flag_address);
    907 
     908     // If the initialized_flag is true, call the `select` function to
     909     // potentially update the selected value and index with the operand element
     910     // currently being visited.
    911     llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &b_);
    912     llvm::Value* operand_address =
    913         operand_array.EmitArrayElementAddress(operand_index, &b_);
    914     llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
    915         llvm_ir::PrimitiveTypeToIrType(PRED,
    916                                        ir_emitter_context_->llvm_module()),
    917         "select_return_buffer", &b_);
    918     TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
    919         *select_and_scatter->select(),
    920         {selected_value_address, operand_address}, select_return_buffer));
    921     llvm::Value* result = Load(select_return_buffer);
    922 
     923     // If the 'select' function returns false, update the selected value and the
     924     // index to the operand element currently being visited.
    925     llvm::Value* cond = ICmpNE(
    926         result,
    927         llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(
    928                                    PRED, ir_emitter_context_->llvm_module()),
    929                                0),
    930         "boolean_predicate");
    931     llvm_ir::LlvmIfData if_select_lhs =
    932         llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
    933     llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
    934     Store(Load(operand_address), selected_value_address);
    935     save_operand_index(operand_index);
    936 
    937     // After iterating over the window elements, scatter the source element to
    938     // the selected index of the output. The value we store at the output
    939     // location is computed by calling the `scatter` function with the source
    940     // value and the current output value.
    941     llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
    942                                    &b_);
    943     std::vector<llvm::Value*> selected_multi_index;
    944     for (int64 i = 0; i < rank; ++i) {
    945       llvm::Value* selected_index_address_slot =
    946           InBoundsGEP(selected_index_address, {b_.getInt32(i)});
    947       selected_multi_index.push_back(Load(selected_index_address_slot));
    948     }
    949     llvm::Value* source_value_address =
    950         GetIrArray(*source, *select_and_scatter)
    951             .EmitArrayElementAddress(source_index, &b_);
    952     IrArray::Index selected_index(selected_multi_index,
    953                                   select_and_scatter->shape(),
    954                                   operand_index.GetType());
    955     llvm::Value* output_value_address =
    956         GetIrArray(*select_and_scatter, *select_and_scatter)
    957             .EmitArrayElementAddress(selected_index, &b_);
    958     return EmitAtomicOperationForNestedComputation(
    959         *select_and_scatter->scatter(), output_value_address,
    960         source_value_address);
    961   };
    962 
    963   UpdateLaunchDimensions(
    964       launch_dimensions,
    965       // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk
    966       // consisting of two thunks, an initializer KernelThunk that initializes
    967       // the output and another KernelThunk that accumulates the scattered
    968       // elements.
    969       select_and_scatter_thunk->thunks().back().get(),
    970       ir_emitter_context_->llvm_module());
    971   AddThunkToThunkSequence(std::move(select_and_scatter_thunk));
    972   return ParallelLoopEmitter(loop_body_emitter, source->shape(),
    973                              launch_dimensions, &b_)
    974       .EmitLoop(IrName(select_and_scatter), index_type);
    975 }
    976 
    977 Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) {
    978   HloComputation* condition = xla_while->while_condition();
    979   TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) &&
    980                condition->root_instruction()->shape().element_type() == PRED)
    981       << "While condition computation must return bool";
    982   // Build ForThunk for conformant while loops, otherwise build WhileThunk.
    983   auto config = xla_while->backend_config<WhileLoopBackendConfig>();
    984   if (config.ok() && config.ValueOrDie().has_known_trip_count()) {
    985     AddThunkToThunkSequence(
    986         BuildForThunk(xla_while, config.ValueOrDie().known_trip_count().n()));
    987   } else {
    988     AddThunkToThunkSequence(BuildWhileThunk(xla_while));
    989   }
    990   return Status::OK();
    991 }
    992 
    993 Status IrEmitterUnnested::HandleRng(HloInstruction* rng) {
    994   // Build the kernel to generate the random numbers.
    995   //
    996   // Unroll the kernel so that the duplicated computation that calculates the
    997   // 128 bit sample can be optimized away by LLVM.
    998   std::unique_ptr<KernelThunk> rng_thunk = BuildKernelThunk(
    999       rng, /*implements_whole_instruction=*/false, ComputeMaxUnrollFactor(rng));
   1000   ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
   1001   for (const HloInstruction* operand : rng->operands()) {
   1002     operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
   1003       return GetIrArray(*operand, *rng).EmitReadArrayElement(index, &b_);
   1004     };
   1005   }
   1006   TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk(
   1007       *rng,
   1008       GpuElementalIrEmitter(hlo_module_config_, module_, &b_,
   1009                             GetNestedComputer())
   1010           .MakeElementGenerator(rng, operand_to_generator),
   1011       rng_thunk.get()));
   1012 
    1013   // Emit a kernel to increment the global state for the Philox RNG algorithm.
   1014   std::unique_ptr<Thunk> increment_seed_thunk =
   1015       BuildKernelThunk(rng, /*implements_whole_instruction=*/false);
   1016   llvm_ir::IncrementVariableForPhiloxRngState(1, module_, &b_);
   1017 
   1018   // Build the SequentialThunk for the RNG hlo.
   1019   std::vector<std::unique_ptr<Thunk>> thunks;
   1020   thunks.reserve(2);
   1021   thunks.push_back(std::move(rng_thunk));
   1022   thunks.push_back(std::move(increment_seed_thunk));
   1023   AddThunkToThunkSequence(
   1024       absl::make_unique<SequentialThunk>(std::move(thunks), rng));
   1025 
   1026   return Status::OK();
   1027 }
   1028 
   1029 Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
   1030   const HloInstruction* operand = scatter->operand(0);
   1031   const HloInstruction* scatter_indices = scatter->operand(1);
   1032   const HloInstruction* updates = scatter->operand(2);
   1033 
   1034   std::vector<std::unique_ptr<Thunk>> thunks;
   1035 
   1036   // Copy the operand into the output if it's not the same buffer already.
   1037   auto operand_buffer = GetAllocationSlice(*operand);
   1038   auto destination_buffer = GetAllocationSlice(*scatter);
   1039   if (operand_buffer != destination_buffer) {
   1040     thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
   1041         /*source_address=*/operand_buffer,
   1042         /*destination_buffer=*/destination_buffer,
   1043         /*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()), scatter));
   1044   }
   1045 
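           // This kernel implements the whole instruction only if no copy thunk was
           // needed above (i.e. `thunks` is still empty).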
   1046   thunks.push_back(
   1047       BuildKernelThunk(scatter,
   1048                        /*implements_whole_instruction=*/thunks.empty()));
   1049 
   1050   TF_RETURN_IF_ERROR(EmitScatter(
   1051       thunks.back().get(), scatter,
   1052       /*scatter_indices_gen=*/
   1053       [=](const IrArray::Index& index) {
   1054         return GetIrArray(*scatter_indices, *scatter)
   1055             .EmitReadArrayElement(index, &b_, "scatter_index");
   1056       },
   1057       /*updates_gen=*/
   1058       [=](const IrArray::Index& index) {
   1059         return GetIrArray(*updates, *scatter)
   1060             .EmitReadArrayElement(index, &b_, "update");
   1061       }));
   1062 
   1063   // Elide the sequential thunk if there's no copy.
   1064   if (thunks.size() == 1) {
   1065     AddThunkToThunkSequence(std::move(thunks[0]));
   1066   } else {
   1067     AddThunkToThunkSequence(
   1068         absl::make_unique<SequentialThunk>(std::move(thunks), scatter));
   1069   }
   1070 
   1071   return Status::OK();
   1072 }
   1073 
   1074 Status IrEmitterUnnested::EmitScatter(
   1075     Thunk* thunk, HloInstruction* scatter,
   1076     const llvm_ir::ElementGenerator& scatter_indices_gen,
   1077     const llvm_ir::ElementGenerator& updates_gen) {
   1078   const HloInstruction* operand = scatter->operand(0);
   1079   const HloInstruction* scatter_indices = scatter->operand(1);
   1080   const HloInstruction* updates = scatter->operand(2);
   1081   const ScatterDimensionNumbers& dim_numbers =
   1082       scatter->scatter_dimension_numbers();
   1083   CHECK(ShapeUtil::Equal(scatter->shape(), operand->shape()));
   1084 
   1085   auto loop_body_emitter = [&](const IrArray::Index& index) -> Status {
   1086     std::vector<llvm::Value*> raw_window_multidim;
   1087     std::vector<llvm::Value*> input_scatter_multidim;
   1088     std::vector<int64> raw_window_bounds;
   1089 
   1090     // Partition the index into window indices and scatter indices.
   1091     for (int64 i = 0, e = index.size(); i != e; ++i) {
    1092       // For window indices, also remember the window size; this comes in handy
    1093       // later.
   1094       if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) {
   1095         raw_window_multidim.push_back(index[i]);
   1096         raw_window_bounds.push_back(updates->shape().dimensions(i));
   1097       } else {
   1098         input_scatter_multidim.push_back(index[i]);
   1099       }
   1100     }
   1101     DCHECK_EQ(raw_window_multidim.size(),
   1102               dim_numbers.update_window_dims_size());
   1103 
   1104     // Apply inserted_window_dims to the window dimensions.
   1105     int64 raw_window_multidim_idx = 0;
   1106     std::vector<llvm::Value*> input_window_multidim;
   1107     std::vector<int64> input_window_bounds;
   1108     for (int64 i = 0, e = operand->shape().rank(); i != e; ++i) {
   1109       if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) {
   1110         input_window_bounds.push_back(1);  // Trivial dimension.
   1111         input_window_multidim.push_back(index.GetConstantWithIndexType(0));
   1112       } else {
   1113         input_window_bounds.push_back(
   1114             raw_window_bounds[raw_window_multidim_idx]);
   1115         input_window_multidim.push_back(
   1116             raw_window_multidim[raw_window_multidim_idx]);
   1117         ++raw_window_multidim_idx;
   1118       }
   1119     }
   1120     DCHECK_EQ(input_window_multidim.size(), operand->shape().rank());
   1121 
   1122     // Insert a 1 dimension at the end if index_vector_dim requests one.
   1123     Shape scatter_indices_shape = scatter_indices->shape();
   1124     if (dim_numbers.index_vector_dim() == scatter_indices_shape.rank()) {
   1125       scatter_indices_shape.add_dimensions(1);
   1126       scatter_indices_shape.mutable_layout()->add_minor_to_major(
   1127           dim_numbers.index_vector_dim());
   1128     }
   1129 
   1130     // Now load the indices corresponding to the current window from
   1131     // scatter_indices.
   1132     std::vector<llvm::Value*> raw_scatter_index_multidim =
   1133         input_scatter_multidim;
   1134     raw_scatter_index_multidim.insert(
   1135         raw_scatter_index_multidim.begin() + dim_numbers.index_vector_dim(),
   1136         nullptr);
   1137     llvm::Value* is_in_bounds = b_.getTrue();
   1138     for (int64 i = 0, e = dim_numbers.scatter_dims_to_operand_dims_size();
   1139          i != e; ++i) {
    1140       // Our index is stored along index_vector_dim; insert that dimension
    1141       // into the lookup index into scatter_indices.
   1142       raw_scatter_index_multidim[dim_numbers.index_vector_dim()] =
   1143           index.GetConstantWithIndexType(i);
   1144       llvm_ir::IrArray::Index raw_scatter_index_index(
   1145           raw_scatter_index_multidim, scatter_indices_shape, index.GetType());
   1146 
   1147       int64 operand_dim = dim_numbers.scatter_dims_to_operand_dims(i);
   1148       TF_ASSIGN_OR_RETURN(
   1149           llvm::Value* const loaded_scatter_index,
   1150           scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape(
   1151               scatter_indices_shape, scatter_indices->shape(), &b_)));
   1152       // And add the index to our window index. This yields the output index.
   1153       llvm::Value* casted_scatter_index =
   1154           IntCast(loaded_scatter_index, index.GetType(),
   1155                   /*isSigned=*/true);
   1156       llvm::Value* dim_offset =
   1157           Add(input_window_multidim[operand_dim], casted_scatter_index);
   1158       input_window_multidim[operand_dim] = dim_offset;
   1159 
   1160       // Also do the bounds check now.
   1161       int64 max_index = operand->shape().dimensions(operand_dim) -
   1162                         input_window_bounds[operand_dim] + 1;
   1163       // is_in_bounds = index >= 0 && index < dim_size-window_size+1
   1164       //   --> index u< dim_size-window_size+1
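               // (Editorial note: the scatter index is sign-extended above, so a
               // negative index wraps to a huge unsigned value and the single
               // unsigned comparison below also rejects indices < 0.)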
   1165       is_in_bounds =
   1166           And(is_in_bounds, ICmpULT(casted_scatter_index,
   1167                                     index.GetConstantWithIndexType(max_index)));
   1168     }
   1169 
   1170     llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse(
   1171         is_in_bounds, "scatter.in_bounds", &b_, /*emit_else=*/false);
   1172     llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_);
    1173     // All done; now just read the update value for this window element and do
    1174     // an atomic store to the calculated location in the output.
   1175     llvm_ir::IrArray::Index input_window_index(input_window_multidim,
   1176                                                index.GetType());
   1177     HloInstruction* output_hlo =
   1178         scatter->IsFused() ? scatter->parent()->FusionInstruction() : scatter;
   1179     llvm::Value* output_address =
   1180         GetIrArray(*output_hlo, *output_hlo)
   1181             .EmitArrayElementAddress(input_window_index, &b_);
   1182     llvm::Value* input_address = Alloca(llvm_ir::PrimitiveTypeToIrType(
   1183         updates->shape().element_type(), module_));
   1184     TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index));
   1185     Store(input_ir_value, input_address);
   1186     return EmitAtomicOperationForNestedComputation(
   1187         *scatter->to_apply(), output_address, input_address);
   1188   };
   1189 
   1190   // Launch a kernel that reads every element in the updates tensor. We could
   1191   // also do one kernel per window instead if bounds checks turn out to be a
   1192   // bottleneck.
   1193   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
   1194       updates->shape(), ir_emitter_context_->device_description());
   1195   UpdateLaunchDimensions(launch_dimensions, thunk,
   1196                          ir_emitter_context_->llvm_module());
   1197 
   1198   return ParallelLoopEmitter(loop_body_emitter, updates->shape(),
   1199                              launch_dimensions, &b_)
   1200       .EmitLoop(IrName(scatter),
   1201                 GetIndexTypeForKernel(scatter, launch_dimensions.launch_bound(),
   1202                                       &b_));
   1203 }
   1204 
   1205 Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
   1206   return IrEmitter::HandleSelect(select);
   1207 }
   1208 
   1209 Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
   1210   std::vector<std::unique_ptr<Thunk>> thunks;
   1211   Shape keys_shape = sort->operand(0)->shape();
   1212   int64 dimension_to_sort = sort->dimensions(0);
   1213   for (int64 i = 0; i < sort->operand_count(); ++i) {
   1214     ShapeIndex shape_index =
   1215         sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({});
   1216     // We assume that the layout of all involved operands and outputs is the
   1217     // same.
   1218     TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(keys_shape,
   1219                                                   sort->operand(i)->shape()));
   1220     TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
   1221         keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index)));
   1222 
   1223     // If possible, we share buffers. If that is not possible, we need to copy
   1224     // the values, because the emitter does the sorting in-place.
   1225     auto destination_buffer = GetAllocationSlice(*sort, shape_index);
   1226     auto source_address = GetAllocationSlice(*sort->operand(i));
   1227     if (destination_buffer != source_address) {
   1228       // TODO(b/26783907): Figure out why we never seem to share buffers for
   1229       // key/value sort.
   1230       thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
   1231           /*source_address=*/source_address,
   1232           /*destination_buffer=*/destination_buffer,
   1233           /*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape()),
   1234           nullptr));
   1235     }
   1236   }
   1237 
   1238   uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
   1239   int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound);
   1240   CHECK_GE(1ULL << num_stages, dimension_to_sort_bound);
   1241   CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound);
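           // Editorial illustration: for dimension_to_sort_bound = 5, Log2Ceiling
           // gives num_stages = 3, since 2^3 = 8 >= 5 > 2^2 = 4.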
   1242 
   1243   // Naive C++ code for the outer loops:
   1244   //
   1245   // for (int64 stage = 0; stage < Log2Ceiling(dimension_to_sort_bound);
   1246   //     ++stage) {
   1247   //   int64 first_xor_mask = (1LL << (stage + 1)) - 1;
   1248   //   SortInPlace(first_xor_mask);
   1249   //   for (int64 mask = stage - 1; mask >= 0; --mask) {
   1250   //     int64 later_xor_mask = 1LL << mask;
   1251   //     SortInPlace(later_xor_mask);
   1252   //   }
   1253   // }
   1254   //
   1255   // This follows the alternative representation of the algorithm described on
   1256   // Wikipedia: https://en.wikipedia.org/wiki/Bitonic_sorter
   1257   //
   1258   // Each mask specifies how to derive from one position in the array the
   1259   // position with which it should be compared (we calculate the xor of the
   1260   // position with the mask).
   1261   // As an optimization, we can move the 'mask' loop to inside the
   1262   // sorting/comparison loop if the comparisons happen within a small block of
   1263   // the array. To make this work, we collect all consecutive masks that are
   1264   // smaller than our chosen power of 2 tile size, and pass them to SortInPlace.
   1265   // Each thread then processes one tile of data.
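           //
           // Editorial illustration (not part of the original comment): for
           // dimension_to_sort_bound = 8 (num_stages = 3) the pseudocode above
           // produces the xor mask sequence 1, 3, 1, 7, 2, 1. Because kTileSize
           // would be min(2048, 8) = 8 < 128, 'no_tiling' below is true and each
           // mask gets its own SortInPlace kernel; for large inputs (kTileSize =
           // 2048) runs of consecutive masks below 2048 are combined into a single
           // kernel instead.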
   1266 
   1267   const uint64 kTileSize = std::min(2048ULL, 1ULL << num_stages);
   1268 
   1269   // If we cannot combine several xor masks together, we don't use tiling, so we
    1270   // calculate the standard launch dimensions for the shape. However, we only
   1271   // need to iterate through ~half of the dimension to sort (rounded up to the
   1272   // next highest power of 2), because each iteration compares one pair of
   1273   // elements.
   1274   Shape standard_iteration_shape = keys_shape;
   1275   uint64 standard_num_iterations_in_sort_dim = 1ULL << (num_stages - 1);
   1276   standard_iteration_shape.set_dimensions(dimension_to_sort,
   1277                                           standard_num_iterations_in_sort_dim);
   1278   LaunchDimensions standard_launch_dimensions = CalculateLaunchDimensions(
   1279       standard_iteration_shape, ir_emitter_context_->device_description());
   1280 
   1281   // Calculate the launch dimensions for the case where we use tiling. We split
   1282   // the dimension that should be sorted into tiles of size 'kTileSize'. This
   1283   // means we first need to round 'dimension_to_sort_bound' up to be a multiple
   1284   // of the tile size.
   1285   int64 rounded_bound = RoundUpToNearest(dimension_to_sort_bound, kTileSize);
   1286   Shape iteration_shape = keys_shape;
   1287 
   1288   // We iterate through the element pairs that should be compared.
   1289   uint64 num_iterations_in_sort_dim = rounded_bound / 2;
   1290   iteration_shape.set_dimensions(dimension_to_sort, num_iterations_in_sort_dim);
   1291   uint64 num_iterations = ShapeUtil::ElementsIn(iteration_shape);
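           // Editorial illustration: dimension_to_sort_bound = 5000 with kTileSize =
           // 2048 gives rounded_bound = 6144 and num_iterations_in_sort_dim = 3072.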
   1292 
    1293   // For correctness reasons we need exactly 'kTileSize' / 2 threads per
   1294   // block. Each thread is responsible for copying exactly two adjacent elements
   1295   // into shared memory, and then does a comparison of two possibly different
   1296   // elements taken from shared memory.
   1297   const uint64 kThreadsPerBlock = kTileSize / 2;
   1298 
   1299   // Check whether we should use any tiling. We might not be able to use it if
    1300   // we don't have enough threads or enough shared memory. Tiling also does
    1301   // not give a speedup if the tile size is < 128.
   1302   int64 total_shared_memory_needed = 0;
   1303   for (int64 i = 0; i < sort->operand_count(); ++i) {
   1304     total_shared_memory_needed +=
   1305         kTileSize * ShapeUtil::ByteSizeOfPrimitiveType(
   1306                         sort->operand(i)->shape().element_type());
   1307   }
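           // Editorial illustration: with the maximum kTileSize of 2048, a
           // single-operand F32 sort uses kThreadsPerBlock = 1024 and needs
           // 2048 * 4 bytes = 8 KiB of shared memory per block.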
   1308   bool no_tiling =
   1309       kTileSize < 128 ||
   1310       kThreadsPerBlock >
   1311           ir_emitter_context_->device_description().threads_per_block_limit() ||
   1312       total_shared_memory_needed >
   1313           ir_emitter_context_->device_description().shared_memory_per_block();
   1314 
   1315   uint64 num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock);
   1316   LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock);
   1317 
   1318   auto emit_kernel = [&](absl::Span<const int64> xor_masks) {
   1319     thunks.push_back(
   1320         BuildKernelThunk(sort, /*implements_whole_instruction=*/false));
   1321     LaunchDimensions launch_dimensions = xor_masks.size() > 1
   1322                                              ? tiled_launch_dimensions
   1323                                              : standard_launch_dimensions;
   1324     UpdateLaunchDimensions(launch_dimensions, thunks.back().get(),
   1325                            ir_emitter_context_->llvm_module());
   1326     std::vector<IrArray> values_arrays;
   1327     values_arrays.reserve(sort->operand_count());
   1328     for (int64 i = 0; i < sort->operand_count(); ++i) {
   1329       ShapeIndex shape_index =
   1330           sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({});
   1331       values_arrays.push_back(GetIrArray(*sort, *sort, shape_index));
   1332     }
   1333     return llvm_ir::EmitSortInPlace(
   1334         dimension_to_sort, values_arrays, IrName(sort), xor_masks, &b_,
   1335         launch_dimensions,
   1336         xor_masks.size() > 1 ? num_iterations_in_sort_dim
   1337                              : standard_num_iterations_in_sort_dim,
   1338         kTileSize,
   1339         [&](absl::Span<llvm::Value* const> operands, llvm::Value* output) {
   1340           return EmitCallToNestedComputation(*sort->to_apply(), operands,
   1341                                              output);
   1342         });
   1343   };
   1344   std::vector<int64> xor_masks;
   1345   for (int64 stage = 0; stage < num_stages; ++stage) {
   1346     for (int64 mask = stage; mask >= 0; --mask) {
   1347       int64 xor_mask;
   1348       if (mask == stage) {
   1349         xor_mask = (1LL << (stage + 1)) - 1;
   1350       } else {
   1351         xor_mask = 1LL << mask;
   1352       }
   1353       if (xor_mask >= kTileSize || no_tiling) {
   1354         if (!xor_masks.empty()) {
   1355           TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
   1356           xor_masks.clear();
   1357         }
   1358         TF_RETURN_IF_ERROR(emit_kernel({xor_mask}));
   1359       } else {
   1360         xor_masks.push_back(xor_mask);
   1361       }
   1362     }
   1363   }
   1364   if (!xor_masks.empty()) {
   1365     TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
   1366   }
   1367 
   1368   AddThunkToThunkSequence(
   1369       absl::make_unique<SequentialThunk>(std::move(thunks), sort));
   1370   return Status::OK();
   1371 }
   1372 
   1373 Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) {
   1374   AddThunkToThunkSequence(
   1375       BuildKernelThunk(tuple_select, /*implements_whole_instruction=*/true));
   1376   return IrEmitter::HandleTupleSelect(tuple_select);
   1377 }
   1378 
   1379 namespace {
   1380 
   1381 bool IsScalarAddComputation(HloComputation* computation) {
   1382   return Match(computation->root_instruction(),
   1383                m::AddAnyOrder(m::Parameter(0), m::Parameter(1))
   1384                    .WithShape(m::Shape().IsEffectiveScalar()));
   1385 }
   1386 
   1387 }  // namespace
   1388 
   1389 Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
   1390   VLOG(2) << "AllReduce; replica count: " << hlo_module_config_.replica_count()
   1391           << "; operand count: " << crs->operand_count()
   1392           << "; NCCL is enabled: " << NcclAllReduceThunk::NcclIsEnabled();
   1393 
   1394   // Note the replica_count == 1 case is handled via device-to-device copy
   1395   // below.
   1396   bool should_use_nccl_thunk =
   1397       hlo_module_config_.replica_count() > 1 &&
   1398       crs->IsCrossReplicaAllReduce() &&
   1399       crs->operand_count() == 1 &&  // One array to reduce.
   1400       crs->operand(0)->shape().element_type() == F32 &&
   1401       // Check the computation is a summation.
   1402       IsScalarAddComputation(crs->to_apply());
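           // Editorial sketch (illustration only, not from the original source) of
           // an HLO module that satisfies these checks when run with
           // replica_count > 1:
           //
           //   add {
           //     p0 = f32[] parameter(0)
           //     p1 = f32[] parameter(1)
           //     ROOT sum = f32[] add(p0, p1)
           //   }
           //   crs = f32[1024] all-reduce(f32[1024] x), to_apply=add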
   1403 
   1404   if (should_use_nccl_thunk) {
   1405     CHECK(crs->operand(0)->shape().IsArray())
   1406         << "Operands to all-reduce must be arrays: " << crs->ToString();
   1407     AddThunkToThunkSequence(absl::make_unique<NcclAllReduceThunk>(
   1408         /*replica_count=*/hlo_module_config_.replica_count(),
   1409         /*elements=*/ShapeUtil::ElementsIn(crs->operand(0)->shape()),
   1410         /*source_address=*/GetAllocationSlice(*crs->operand(0)),
   1411         /*destination_buffer=*/GetAllocationSlice(*crs), crs));
   1412     return Status::OK();
   1413   }
   1414 
   1415   if (hlo_module_config_.replica_count() != 1) {
   1416     // TODO(b/33011107): Support more AllReduce configurations on GPU.
   1417     string message = absl::StrFormat(
   1418         "Requested AllReduce not implemented on GPU; replica_count: %d; "
   1419         "operand_count: %d; IsCrossReplicaAllReduce: %d; NCCL support: %d",
   1420         hlo_module_config_.replica_count(), crs->operand_count(),
   1421         crs->IsCrossReplicaAllReduce(), NcclAllReduceThunk::NcclIsEnabled());
   1422     if (crs->operand_count() > 0) {
   1423       absl::StrAppendFormat(
   1424           &message, "; first operand array element-type: %s",
   1425           PrimitiveType_Name(crs->operand(0)->shape().element_type()));
   1426     }
   1427     return Unimplemented("%s", message);
   1428   }
   1429 
   1430   // CRS with one operand and one replica is simply the identity function.
   1431   // Buffer assignment expects a copy, so that's what we do.
   1432   //
   1433   // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely
   1434   // in algebraic-simplifier, but currently on some platforms
   1435   // HloModuleConfig::num_replicas changes between when the module is compiled
   1436   // and when it's run.
   1437   if (crs->operand_count() == 1) {
   1438     CHECK(crs->operand(0)->shape().IsArray())
   1439         << "Operands to all-reduce must be arrays: " << crs->ToString();
   1440     AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
   1441         /*source_address=*/GetAllocationSlice(*crs->operand(0)),
   1442         /*destination_buffer=*/GetAllocationSlice(*crs),
   1443         /*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs));
   1444     return Status::OK();
   1445   }
   1446 
   1447   // One-replica CRS with multiple operands produces a tuple of the inputs.
   1448   // Again, buffer assignment expects us to copy each.
   1449   std::vector<std::unique_ptr<Thunk>> thunks;
   1450   std::vector<BufferAllocation::Slice> tuple_element_buffers;
   1451   for (int64 i = 0; i < crs->operand_count(); ++i) {
   1452     tuple_element_buffers.push_back(ir_emitter_context_->buffer_assignment()
   1453                                         .GetUniqueSlice(crs, {i})
   1454                                         .ValueOrDie());
   1455     thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
   1456         /*source_address=*/GetAllocationSlice(*crs->operand(i)),
   1457         /*destination_buffer=*/tuple_element_buffers.back(),
   1458         /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr));
   1459   }
   1460 
   1461   // Output a tuple of the buffers above.
   1462   thunks.push_back(absl::make_unique<TupleThunk>(
   1463       tuple_element_buffers, GetAllocationSlice(*crs), nullptr));
   1464   AddThunkToThunkSequence(
   1465       absl::make_unique<SequentialThunk>(std::move(thunks), crs));
   1466   return Status::OK();
   1467 }
   1468 
   1469 Status IrEmitterUnnested::HandleAfterAll(HloInstruction* after_all) {
   1470   return Status::OK();
   1471 }
   1472 
   1473 Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) {
   1474   AddThunkToThunkSequence(BuildInfeedThunk(infeed));
   1475   return Status::OK();
   1476 }
   1477 
   1478 Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) {
   1479   AddThunkToThunkSequence(BuildOutfeedThunk(outfeed));
   1480   return Status::OK();
   1481 }
   1482 
   1483 // Figures out how to access the buffers for all subshapes of hlo's operands and
   1484 // for hlo itself (i.e. all the buffers produced by HLO).
   1485 //
   1486 // Returns a map keyed on the pair {HloInstruction, ShapeIndex}.  The value for
   1487 // this key is a pair {Slice, ShapeIndex}, where the slice tells you the root
   1488 // buffer to look in, and the ShapeIndex describes how to dereference starting
   1489 // at that buffer to get to the buffer in question.
   1490 //
   1491 // For example, if {hlo, {1}} is mapped to {slice, {3, 4}}, then the buffer for
   1492 // hlo at ShapeIndex {1} (i.e. the buffer for the second tuple element of hlo)
   1493 // is found at slice[3][4].  That is, slice is a void***, which we dereference
   1494 // twice -- first at index 3, and then at index 4 -- to get the address of our
   1495 // buffer.
   1496 //
   1497 // This function conservatively assumes that we'll touch all sub-buffers of
   1498 // every operand and of the output.
   1499 static std::map<std::pair<const HloInstruction*, ShapeIndex>,
   1500                 std::pair<BufferAllocation::Slice, ShapeIndex>>
   1501 GetHloBufferSlices(const HloInstruction* hlo,
   1502                    const BufferAssignment& buffer_assn) {
   1503   std::map<std::pair<const HloInstruction*, ShapeIndex>,
   1504            std::pair<BufferAllocation::Slice, ShapeIndex>>
   1505       slices;
   1506 
   1507   // Tries to find a slice plus an array of indices i1, ..., iN such that the
   1508   // sub-buffer for instr at index can be found at slice[i1]...[iN].
   1509   auto find_slice_for = [&](const HloInstruction* instr,
   1510                             const ShapeIndex& index)
   1511       -> optional<std::pair<BufferAllocation::Slice, ShapeIndex>> {
   1512     // Simple, common case: Is the buffer for instr known at runtime?  If so,
   1513     // we're done.
   1514     auto slice = buffer_assn.GetUniqueSlice(instr, index);
   1515     if (slice.ok()) {
   1516       return {{slice.ValueOrDie(), ShapeIndex()}};
   1517     }
   1518 
   1519     // If that didn't work, walk up any bitcasts that we might see.  These must
   1520     // appear before any GTE instructions, because it's illegal to bitcast to a
   1521     // tuple type.
   1522     const HloInstruction* parent = instr;
   1523     while (parent->opcode() == HloOpcode::kBitcast) {
   1524       parent = parent->operand(0);
   1525 
   1526       auto slice = buffer_assn.GetUniqueSlice(parent, {});
   1527       if (slice.ok()) {
   1528         return {{slice.ValueOrDie(), ShapeIndex()}};
   1529       }
   1530     }
   1531 
   1532     // Check whether instr is a GTE instruction.  If it is, see if we can get a
   1533     // buffer for its parent, and continue walking up parents until we find a
   1534     // defined buffer or we hit something that's not a GTE.
   1535     ShapeIndex gte_indices;
   1536     while (parent->opcode() == HloOpcode::kGetTupleElement) {
   1537       gte_indices.push_front(parent->tuple_index());
   1538       parent = parent->operand(0);
   1539 
   1540       auto slice = buffer_assn.GetUniqueSlice(parent, {});
   1541       if (slice.ok()) {
   1542         return {{slice.ValueOrDie(), gte_indices}};
   1543       }
   1544     }
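             // Editorial example: for instr = get-tuple-element(
             // get-tuple-element(w, 1), 0) where only 'w' has a known slice, the
             // loop above returns that slice with gte_indices = {1, 0}.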
   1545 
   1546     // Finally, if we don't know the buffer for instr at index, see if we know
   1547     // the buffer for instr at index without its last element.  If so, we can
   1548     // dynamically find the buffer for instr by dereferencing a pointer in that
   1549     // buffer.  Continue looking this way until we run out of elements in
   1550     // 'index'.
   1551     //
   1552     // We can almost always get a buffer without resorting to this.  The only
   1553     // exception is for cases where the relevant sub-buffer is truly unknowable,
   1554     // for example the sub-buffer of a tuple-shaped select.
   1555     ShapeIndex new_index = index;
   1556     while (!new_index.empty()) {
   1557       gte_indices.push_front(new_index.back());
   1558       new_index.pop_back();
   1559       auto slice = buffer_assn.GetUniqueSlice(instr, new_index);
   1560       if (slice.ok()) {
   1561         return {{slice.ValueOrDie(), gte_indices}};
   1562       }
   1563     }
   1564 
   1565     return nullopt;
   1566   };
   1567 
   1568   // Adds entries for all subshapes of instr to `slices`.
   1569   auto add_slices_for = [&](const HloInstruction* instr) {
   1570     ShapeUtil::ForEachSubshape(
   1571         instr->shape(), [&](const Shape& /*shape*/, const ShapeIndex& index) {
   1572           if (slices.count({instr, index})) {
   1573             // HLOs can have duplicate operands; don't bother redoing work.
   1574             return;
   1575           }
   1576           auto maybe_slice = find_slice_for(instr, index);
   1577           if (maybe_slice.has_value()) {
   1578             slices[{instr, index}] = *maybe_slice;
   1579           } else {
   1580             VLOG(1) << "Couldn't find buffer for " << instr->ToString()
   1581                     << " at index " << index.ToString();
   1582           }
   1583         });
   1584   };
   1585 
   1586   add_slices_for(hlo);
   1587   for (const HloInstruction* operand : hlo->operands()) {
   1588     // Conservatively assume we'll need the buffers for all subshapes of the
   1589     // operand.
   1590     add_slices_for(operand);
   1591   }
   1592 
   1593   return slices;
   1594 }
   1595 
   1596 std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
   1597     const HloInstruction* inst, bool implements_whole_instruction,
   1598     int unroll_factor) {
   1599   const BufferAssignment& buffer_assn =
   1600       ir_emitter_context_->buffer_assignment();
   1601 
   1602   std::map<std::pair<const HloInstruction*, ShapeIndex>,
   1603            std::pair<BufferAllocation::Slice, ShapeIndex>>
   1604       hlo_slices = GetHloBufferSlices(inst, buffer_assn);
   1605 
   1606   // Figure out which buffer allocations need to be passed as arguments to our
   1607   // kernel.  This is simply all of the allocations referenced in hlo_slices,
   1608   // plus the XLA temp buffer (if we have it).  We always include the temp
   1609   // buffer because even if the kernel itself doesn't use it, a nested
   1610   // subcomputation within the kernel (e.g. a kMap's computation) might.
   1611   std::unordered_set<const BufferAllocation*> buffers_needed;
   1612   for (const auto& kv : hlo_slices) {
   1613     buffers_needed.insert(kv.second.first.allocation());
   1614   }
   1615   absl::optional<const BufferAllocation*> temp_buffer;
   1616   for (const BufferAllocation& alloc : buffer_assn.Allocations()) {
   1617     if (alloc.IsPreallocatedTempBuffer()) {
   1618       if (!temp_buffer.has_value()) {
   1619         temp_buffer = &alloc;
   1620       } else {
   1621         LOG(FATAL) << "Multiple temp buffers found, but only one is allowed!";
   1622       }
   1623     }
   1624   }
   1625   if (temp_buffer.has_value()) {
   1626     buffers_needed.insert(*temp_buffer);
   1627   }
   1628 
   1629   // We'll pass a pointer to each of the elements of `buffers` to our kernel, in
   1630   // this order.
   1631   std::vector<const BufferAllocation*> non_constant_buffers;
   1632   absl::c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers),
   1633                   [](const BufferAllocation* allocation) {
   1634                     return !allocation->is_constant();
   1635                   });
   1636 
   1637   absl::c_sort(non_constant_buffers,
   1638                [](const BufferAllocation* a, const BufferAllocation* b) {
   1639                  return a->index() < b->index();
   1640                });
   1641 
   1642   llvm::Function* kernel = BuildKernelPrototype(*inst, non_constant_buffers);
   1643 
   1644   // Build a map from a BufferAllocation to the corresponding argument in our
   1645   // kernel.
   1646   std::unordered_map<const BufferAllocation*, llvm::Value*> kernel_args;
   1647   {
   1648     auto arg_it = kernel->arg_begin();
   1649     auto buffers_it = non_constant_buffers.begin();
   1650     for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) {
   1651       kernel_args[*buffers_it] = arg_it;
   1652     }
   1653   }
   1654 
   1655   // For each buffer our kernel might want to touch, bind it to a value derived
   1656   // from our kernel args.
   1657   for (const auto& kv : hlo_slices) {
   1658     const HloInstruction* instr = kv.first.first;
   1659     const ShapeIndex& index = kv.first.second;
   1660     const BufferAllocation::Slice& slice = kv.second.first;
   1661     const ShapeIndex& gte_index = kv.second.second;
   1662 
   1663     VLOG(3) << "Buffer for " << instr->ToString() << " at " << index.ToString()
   1664             << " is found in slice " << slice.ToString() << " at GTE index "
   1665             << gte_index.ToString();
   1666 
   1667     llvm::Value* loc;
   1668     if (slice.allocation()->is_constant()) {
   1669       loc = ir_emitter_context_->llvm_module()->getGlobalVariable(
   1670           llvm_ir::ConstantBufferAllocationToGlobalName(*slice.allocation()));
   1671       CHECK_NE(loc, nullptr);
   1672     } else {
   1673       loc = InBoundsGEP(kernel_args.at(slice.allocation()),
   1674                         {b_.getInt64(slice.offset())});
   1675     }
   1676 
   1677     // If gte_index is nonempty, we have to dereference `loc` to get to the
   1678     // value we're ultimately interested in.
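             // (Editorial example: for gte_index = {3, 4} the loop below emits,
             // roughly, loc = ((char**)loc)[3]; loc = ((char**)loc)[4]; matching
             // the slice[3][4] description in GetHloBufferSlices above.)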
   1679     llvm::Type* int8_double_pointer =
   1680         llvm::PointerType::get(b_.getInt8PtrTy(), /*AddressSpace=*/0);
   1681     for (int64 idx : gte_index) {
   1682       loc = BitCast(loc, int8_double_pointer);
   1683       loc = Load(InBoundsGEP(loc, {b_.getInt64(idx)}));
   1684     }
   1685 
   1686     bindings_.BindHloToIrValue(*instr, loc, index);
   1687   }
   1688 
   1689   // Bind the temp buffer so that nested subcomputations can find it if they
    1690   // need it.
   1691   if (temp_buffer.has_value()) {
   1692     bindings_.SetTempBufferBase(kernel_args.at(*temp_buffer));
   1693   } else {
   1694     bindings_.SetTempBufferBase(
   1695         llvm::ConstantPointerNull::get(b_.getInt8PtrTy()));
   1696   }
   1697 
   1698   return absl::make_unique<KernelThunk>(
   1699       non_constant_buffers, kernel->getName(),
   1700       implements_whole_instruction ? inst : nullptr, unroll_factor);
   1701 }
   1702 
   1703 std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk(
   1704     const HloInstruction* inst) {
   1705   const HloInstruction* operand = inst->operand(0);
   1706   CHECK_EQ(HloOpcode::kConstant, operand->opcode());
   1707   return absl::make_unique<HostToDeviceCopyThunk>(
   1708       /*source_address=*/operand->literal().untyped_data(),
   1709       /*destination_buffer=*/GetAllocationSlice(*inst),
   1710       /*mem_size=*/
   1711       llvm_ir::ByteSizeOf(operand->shape(),
   1712                           ir_emitter_context_->llvm_module()->getDataLayout()),
   1713       inst);
   1714 }
   1715 
   1716 std::unique_ptr<Thunk> IrEmitterUnnested::BuildDeviceToDeviceCopyThunk(
   1717     const HloInstruction* inst) {
   1718   const HloInstruction* operand = inst->operand(0);
   1719   return absl::make_unique<DeviceToDeviceCopyThunk>(
   1720       /*source_address=*/GetAllocationSlice(*operand),
   1721       /*destination_buffer=*/GetAllocationSlice(*inst),
   1722       /*mem_size=*/
   1723       llvm_ir::ByteSizeOf(operand->shape(),
   1724                           ir_emitter_context_->llvm_module()->getDataLayout()),
   1725       inst);
   1726 }
   1727 
   1728 std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk(
   1729     const HloInstruction* inst) {
   1730   CHECK_EQ(HloOpcode::kInfeed, inst->opcode());
   1731 
   1732   ShapeTree<BufferAllocation::Slice> slices(inst->shape());
   1733   slices.ForEachMutableElement(
   1734       [&](const ShapeIndex& index, BufferAllocation::Slice* slice) {
   1735         *slice = ir_emitter_context_->buffer_assignment()
   1736                      .GetUniqueSlice(inst, index)
   1737                      .ConsumeValueOrDie();
   1738       });
   1739   return absl::make_unique<InfeedThunk>(slices, inst);
   1740 }
   1741 
   1742 std::unique_ptr<Thunk> IrEmitterUnnested::BuildOutfeedThunk(
   1743     const HloInstruction* inst) {
   1744   CHECK_EQ(HloOpcode::kOutfeed, inst->opcode());
   1745 
   1746   ShapeTree<BufferAllocation::Slice> slices(inst->operand(0)->shape());
   1747   slices.ForEachMutableElement(
   1748       [&](const ShapeIndex& index, BufferAllocation::Slice* slice) {
   1749         auto status_or_slice =
   1750             ir_emitter_context_->buffer_assignment().GetUniqueSlice(
   1751                 inst->operand(0), index);
   1752         if (status_or_slice.ok()) {
   1753           *slice = status_or_slice.ConsumeValueOrDie();
   1754         }
   1755       });
   1756   return absl::make_unique<OutfeedThunk>(std::move(slices), inst);
   1757 }
   1758 
   1759 namespace {
   1760 double GetScalarConstantAsDouble(const Literal& literal) {
   1761   switch (literal.shape().element_type()) {
   1762     case F16:
   1763       return static_cast<double>(literal.Get<Eigen::half>({}));
   1764     case F32:
   1765       return literal.Get<float>({});
   1766     case F64:
   1767       return literal.Get<double>({});
   1768     default:
   1769       LOG(FATAL) << "Unsupported type.";
   1770   }
   1771 }
   1772 }  // namespace
   1773 
   1774 std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
   1775     const HloInstruction* inst) {
   1776   if (inst->opcode() == HloOpcode::kDot) {
   1777     const HloInstruction* lhs = inst->operand(0);
   1778     const HloInstruction* rhs = inst->operand(1);
   1779     return absl::make_unique<GemmThunk>(
   1780         GetAllocationSlice(*lhs),   // The buffer assigned to LHS.
   1781         GetAllocationSlice(*rhs),   // The buffer assigned to RHS.
   1782         GetAllocationSlice(*inst),  // The output buffer.
   1783         lhs->shape(),               // The shape of LHS.
   1784         rhs->shape(),               // The shape of RHS.
   1785         inst->shape(),              // The shape of the output.
   1786         1.0,                        // alpha.
   1787         0.0,                        // beta.
   1788         inst, /*implements_whole_instruction=*/true);
   1789   }
   1790 
   1791   if (inst->opcode() == HloOpcode::kFusion) {
   1792     CHECK_EQ(inst->fusion_kind(), HloInstruction::FusionKind::kOutput);
   1793     const HloInstruction* output_fused_op = inst->fused_expression_root();
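             // Editorial note: the two fused patterns handled below are, roughly,
             //   multiply(dot(lhs, rhs), broadcast(alpha))  -> GEMM scaled by alpha
             //   add(dot(lhs, rhs), bias)                   -> GEMM with beta = 1
             // with the dot possibly appearing as either operand.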
   1794 
   1795     double alpha_value = 1.0;
   1796     const HloInstruction* bias = nullptr;
   1797     const HloInstruction* dot = output_fused_op->operand(0);
   1798     if (output_fused_op->opcode() == HloOpcode::kMultiply) {
   1799       const HloInstruction* alpha = output_fused_op->operand(1);
   1800       if (dot->opcode() != HloOpcode::kDot) {
   1801         std::swap(dot, alpha);
   1802       }
   1803       if (alpha->opcode() == HloOpcode::kBroadcast) {
   1804         alpha = alpha->operand(0);
   1805       }
   1806       if (alpha->opcode() == HloOpcode::kParameter) {
   1807         alpha = inst->operand(alpha->parameter_number());
   1808       }
   1809       // TODO(b/74185543): Remove the following if block once we support fusion
   1810       // with a non-constant as well. Then we will just always use the constant
   1811       // on the device.
   1812       if (alpha->opcode() == HloOpcode::kCopy) {
   1813         alpha = alpha->operand(0);
   1814       }
   1815       alpha_value = GetScalarConstantAsDouble(alpha->literal());
   1816     } else {
   1817       // Fused bias add.
   1818       CHECK_EQ(output_fused_op->opcode(), HloOpcode::kAdd);
   1819       bias = output_fused_op->operand(1);
   1820       if (dot->opcode() != HloOpcode::kDot) {
   1821         std::swap(dot, bias);
   1822       }
   1823       bias = inst->operand(bias->parameter_number());
   1824     }
   1825 
   1826     DCHECK(dot->opcode() == HloOpcode::kDot);
   1827     const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0));
   1828     const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1));
   1829     DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter &&
   1830            rhs_parameter->opcode() == HloOpcode::kParameter);
   1831     const HloInstruction* lhs =
   1832         inst->operand(lhs_parameter->parameter_number());
   1833     const HloInstruction* rhs =
   1834         inst->operand(rhs_parameter->parameter_number());
   1835 
    1836     // The bias is passed inside the output buffer. If the bias and output
    1837     // buffers share an allocation, we can just use it; otherwise copy the bias
    1838     // values into the output buffer first.
   1839     if (bias != nullptr &&
   1840         GetAllocationSlice(*bias) != GetAllocationSlice(*inst)) {
   1841       std::vector<std::unique_ptr<Thunk>> thunks;
   1842       thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
    1843           /*source_address=*/GetAllocationSlice(*bias),
   1844           /*destination_buffer=*/GetAllocationSlice(*inst),
   1845           /*mem_size=*/ShapeUtil::ByteSizeOf(inst->shape()), nullptr));
   1846       thunks.push_back(absl::make_unique<GemmThunk>(
   1847           GetAllocationSlice(*lhs),   // The buffer assigned to LHS.
   1848           GetAllocationSlice(*rhs),   // The buffer assigned to RHS.
   1849           GetAllocationSlice(*inst),  // The output buffer.
   1850           lhs->shape(),               // The shape of LHS.
   1851           rhs->shape(),               // The shape of RHS.
   1852           inst->shape(),              // The shape of the output.
   1853           alpha_value,                // alpha.
   1854           1.0,                        // beta.
   1855           inst, /*implements_whole_instruction=*/false));
   1856       return absl::make_unique<SequentialThunk>(std::move(thunks), inst);
   1857     }
   1858     return absl::make_unique<GemmThunk>(
   1859         GetAllocationSlice(*lhs),     // The buffer assigned to LHS.
   1860         GetAllocationSlice(*rhs),     // The buffer assigned to RHS.
   1861         GetAllocationSlice(*inst),    // The output buffer.
   1862         lhs->shape(),                 // The shape of LHS.
   1863         rhs->shape(),                 // The shape of RHS.
   1864         inst->shape(),                // The shape of the output.
   1865         alpha_value,                  // alpha.
   1866         bias != nullptr ? 1.0 : 0.0,  // beta.
   1867         inst, /*implements_whole_instruction=*/true);
   1868   }
   1869 
   1870   LOG(FATAL) << "Cannot build a GemmThunk for " << inst->ToString();
   1871 }
   1872 
   1873 std::unique_ptr<Thunk> IrEmitterUnnested::BuildFftThunk(
   1874     const HloInstruction* inst) {
   1875   const HloInstruction* operand = inst->operand(0);
   1876   return absl::make_unique<FftThunk>(
   1877       inst->fft_type(), inst->fft_length(),
   1878       /*input_buffer=*/GetAllocationSlice(*operand),
   1879       /*output_buffer=*/GetAllocationSlice(*inst),
   1880       /*input_shape=*/operand->shape(),
   1881       /*output_shape=*/inst->shape(), inst);
   1882 }
   1883 
   1884 std::unique_ptr<Thunk> IrEmitterUnnested::BuildTriangularSolveThunk(
   1885     const HloInstruction* inst) {
   1886   const HloInstruction* a = inst->operand(0);
   1887   const HloInstruction* b = inst->operand(1);
   1888   int64 m = b->shape().dimensions(b->shape().rank() - 2);
   1889   int64 n = b->shape().dimensions(b->shape().rank() - 1);
   1890   int64 batch_size = std::accumulate(
   1891       b->shape().dimensions().begin(), b->shape().dimensions().end() - 2,
   1892       int64{1}, [](int64 a, int64 b) { return a * b; });
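           // Editorial illustration: for b of shape f32[7, 3, 64, 32], m = 64,
           // n = 32, and batch_size = 7 * 3 = 21.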
   1893   int64 elem_size =
   1894       ShapeUtil::ByteSizeOfPrimitiveType(inst->shape().element_type());
   1895   int64 a_batch_stride = inst->triangular_solve_options().left_side()
   1896                              ? m * m * elem_size
   1897                              : n * n * elem_size;
   1898   int64 b_batch_stride = m * n * elem_size;
   1899   return absl::make_unique<TriangularSolveThunk>(
   1900       inst->triangular_solve_options(),
   1901       /*a_input_buffer=*/GetAllocationSlice(*a),
   1902       /*b_input_buffer=*/GetAllocationSlice(*inst),
   1903       inst->shape().element_type(), batch_size, m, n, a_batch_stride,
   1904       b_batch_stride, inst);
   1905 }
   1906 
   1907 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
   1908     HloInstruction* hlo, const ShapeIndex& index) {
   1909   bool fused = HloOpcode::kFusion == hlo->opcode();
   1910   HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo;
   1911   HloInstruction* init_value_operand = [&] {
   1912     switch (inst->opcode()) {
   1913       case HloOpcode::kSelectAndScatter:
   1914         return inst->mutable_operand(2);
   1915       case HloOpcode::kReduce:
   1916         return inst->mutable_operand(1);
   1917       case HloOpcode::kTuple:
   1918         CHECK(hlo->IsMultiOutputFusion())
   1919             << ": " << hlo->ToString() << " is not a multi-output fusion.";
   1920         CHECK(inst->operand(index.back())->opcode() == HloOpcode::kReduce)
   1921             << ": Found '" << inst->operand(index.back())->opcode() << "' in "
   1922             << inst->ToString() << " but expected 'reduce'.";
   1923         // For multi-output fusion look through the tuple.
   1924         return inst->mutable_operand(index.back())->mutable_operand(1);
   1925       default:
   1926         LOG(FATAL) << "Opcode " << inst->opcode()
   1927                    << " should not need an initializer.";
   1928     }
   1929   }();
   1930 
   1931   const HloInstruction* init_value = init_value_operand;
   1932   if (fused && init_value->opcode() == HloOpcode::kParameter) {
   1933     init_value = hlo->operand(init_value->parameter_number());
   1934   }
   1935 
   1936   // Initializer thunks don't implement a whole instruction, and we want to
   1937   // profile the whole instruction instead of the individual thunks it consists
   1938   // of. Therefore we pass nullptr as the HloInstruction* to the thunks we
   1939   // generate below.
   1940   //
   1941   // In the common case, the initializer is a constant.  In this case, emit a
   1942   // device-memset call if we can.  Currently StreamExecutor only supports
   1943   // zeroing and 32-bit memsets.
   1944   if (init_value->IsConstant()) {
   1945     CHECK(ShapeUtil::IsScalar(init_value->shape()));
   1946     int64 num_bytes = ShapeUtil::ByteSizeOfElements(init_value->shape());
   1947     const auto& literal = init_value->literal();
   1948 
   1949     // Are all the bytes of this scalar equal to 0?  If so, we can create a
   1950     // MemzeroThunk.
   1951     absl::Span<const uint8> literal_bytes(
   1952         reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes);
   1953     if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
   1954       return {absl::make_unique<MemzeroThunk>(GetAllocationSlice(*hlo, index),
   1955                                               nullptr)};
   1956     }
   1957 
   1958     // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
   1959     // repeating the literal 4 or 2 times, so long as the destination buffer is
   1960     // an even multiple of 32 bits long.
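             // (Editorial example: a one-byte literal 0xAB yields pattern32 =
             // 0xABABABAB; a two-byte literal such as F16 1.0, bit pattern 0x3C00,
             // yields pattern32 = 0x3C003C00.)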
   1961     const Shape& output_shape = ShapeUtil::GetSubshape(hlo->shape(), index);
   1962     if ((num_bytes == 1 || num_bytes == 2) &&
   1963         ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) {
   1964       uint16 pattern16;
   1965       if (num_bytes == 1) {
   1966         uint8 b = literal_bytes.front();
   1967         pattern16 = uint16{b} | (uint16{b} << 8);
   1968       } else {
   1969         memcpy(&pattern16, literal_bytes.data(), sizeof(pattern16));
   1970       }
   1971       uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
   1972       return {absl::make_unique<Memset32BitValueThunk>(
   1973           pattern32, GetAllocationSlice(*hlo, index), nullptr)};
   1974     }
   1975 
   1976     // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
   1977     // memset so long as all 32-bit words of the scalar are equal to each other.
   1978     if (num_bytes >= 4 && num_bytes % 4 == 0 &&
   1979         memcmp(literal_bytes.data(), literal_bytes.data() + 4,
   1980                literal_bytes.size() - 4) == 0) {
   1981       uint32 word;
   1982       memcpy(&word, literal_bytes.data(), sizeof(word));
   1983       return {absl::make_unique<Memset32BitValueThunk>(
   1984           word, GetAllocationSlice(*hlo, index), nullptr)};
   1985     }
   1986   }
   1987 
   1988   // Otherwise fall back to our slow initializer code.
   1989   std::unique_ptr<KernelThunk> kernel_thunk =
   1990       BuildKernelThunk(hlo, /*implements_whole_instruction=*/false);
   1991   LaunchDimensions launch_dimensions =
   1992       CalculateLaunchDimensions(ShapeUtil::GetSubshape(hlo->shape(), index),
   1993                                 ir_emitter_context_->device_description());
   1994   UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
   1995                          ir_emitter_context_->llvm_module());
   1996 
   1997   if (fused) {
   1998     // If init_value was fused into this reduce we have to generate it first.
   1999     GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
   2000                                             ir_emitter_context_->llvm_module(),
   2001                                             &b_, GetNestedComputer());
   2002 
   2003     FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo),
   2004                                  &elemental_emitter);
   2005     TF_RETURN_IF_ERROR(init_value_operand->Accept(&fused_emitter));
   2006     TF_RETURN_IF_ERROR(
   2007         ParallelLoopEmitter(fused_emitter.GetGenerator(init_value_operand),
   2008                             GetIrArray(*hlo, *hlo, index), launch_dimensions,
   2009                             &b_)
   2010             .EmitLoop(IrName(hlo)));
   2011   } else {
    2012     // In the unfused case the element is already there; just read from it.
   2013     TF_RETURN_IF_ERROR(ParallelLoopEmitter(
   2014                            [=](const IrArray::Index& index) {
   2015                              return GetIrArray(*init_value, *hlo)
   2016                                  .EmitReadArrayElement(index, &b_);
   2017                            },
   2018                            GetIrArray(*hlo, *hlo, index), launch_dimensions,
   2019                            &b_)
   2020                            .EmitLoop(IrName(hlo)));
   2021   }
   2022 
   2023   // Clean up state left behind by emitting the loop above.  (This is normally
   2024   // done in IrEmitterUnnested::Postprocess().)
   2025   bindings_.UnbindAllLocalIrValues();
   2026 
   2027   // Convert unique_ptr<KernelThunk> to StatusOr<unique_ptr<Thunk>>.
   2028   return {std::move(kernel_thunk)};
   2029 }
   2030 
   2031 namespace {
   2032 
   2033 // Checks that the buffers corresponding to the given two HLOs share the same
   2034 // allocation.
   2035 Status CheckHloBuffersShareAllocation(
   2036     const HloInstruction* a, const HloInstruction* b, const ShapeIndex& index,
   2037     const BufferAssignment& buffer_assignment) {
   2038   const BufferAllocation::Slice slice_a =
   2039       buffer_assignment.GetUniqueSlice(a, index).ConsumeValueOrDie();
   2040   const BufferAllocation::Slice slice_b =
   2041       buffer_assignment.GetUniqueSlice(b, index).ConsumeValueOrDie();
   2042   if (slice_a != slice_b) {
   2043     return InternalError(
   2044         "instruction %s %s does not share allocation with instruction %s %s",
   2045         a->ToString(), slice_a.ToString(), b->ToString(), slice_b.ToString());
   2046   }
   2047   return Status::OK();
   2048 }
   2049 
   2050 // Checks that all buffers used during while loop iteration share the same
   2051 // buffer allocation. This includes buffers for while result, while init
   2052 // operand, condition parameter, body parameter and body result.
   2053 // Returns OK on success, error status otherwise.
   2054 Status CheckWhileBuffersShareAllocation(
   2055     const HloInstruction* xla_while,
   2056     const BufferAssignment& buffer_assignment) {
   2057   return ShapeUtil::ForEachSubshapeWithStatus(
   2058       xla_while->shape(),
   2059       [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
   2060         const HloInstruction* condition_parameter =
   2061             xla_while->while_condition()->parameter_instruction(0);
   2062         const HloComputation* body = xla_while->while_body();
   2063         const HloInstruction* body_parameter = body->parameter_instruction(0);
   2064         const HloInstruction* body_result = body->root_instruction();
   2065         TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
   2066             xla_while, xla_while->operand(0), index, buffer_assignment));
   2067         TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
   2068             xla_while, condition_parameter, index, buffer_assignment));
   2069         TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
   2070             xla_while, body_parameter, index, buffer_assignment));
   2071         TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
   2072             xla_while, body_result, index, buffer_assignment));
   2073         return Status::OK();
   2074       });
   2075 }
   2076 
   2077 // Checks that the buffers used in a conditional instruction are shared with the
   2078 // operands and result as follows:
   2079 //   * The result buffer of the conditional should share the allocation with the
   2080 //     result buffers of each branch computation.
   2081 //   * The buffer of operand b+1 should share the allocation with the buffer of
   2082 //     the parameter 0 instruction of the b'th computation.
   2083 Status CheckConditionalBuffersShareAllocation(
   2084     const HloInstruction* conditional,
   2085     const BufferAssignment& buffer_assignment) {
   2086   TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
   2087       conditional->shape(),
   2088       [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
   2089         for (auto branch_computation : conditional->branch_computations()) {
   2090           TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
   2091               conditional, branch_computation->root_instruction(), index,
   2092               buffer_assignment));
   2093         }
   2094         return Status::OK();
   2095       }));
   2096   for (int j = 0; j < conditional->branch_count(); ++j) {
   2097     TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
   2098         conditional->operand(j + 1)->shape(),
   2099         [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
   2100           return CheckHloBuffersShareAllocation(
   2101               conditional->operand(j + 1),
   2102               conditional->branch_computation(j)->parameter_instruction(0),
   2103               index, buffer_assignment);
   2104         }));
   2105   }
   2106   return Status::OK();
   2107 }
   2108 
   2109 }  // namespace
   2110 
   2111 std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk(
   2112     const HloInstruction* hlo) {
   2113   // Check that all while-related buffers share an allocation.
   2114   TF_CHECK_OK(CheckWhileBuffersShareAllocation(
   2115       hlo, ir_emitter_context_->buffer_assignment()));
   2116 
   2117   // Generate thunk sequence for while 'condition'.
   2118   HloComputation* condition = hlo->while_condition();
   2119   IrEmitterUnnested ir_emitter_condition(hlo_module_config_, condition,
   2120                                          ir_emitter_context_);
   2121   TF_CHECK_OK(condition->Accept(&ir_emitter_condition));
   2122 
   2123   // Generate thunk sequence for while 'body'.
   2124   HloComputation* body = hlo->while_body();
   2125   IrEmitterUnnested ir_emitter_body(hlo_module_config_, body,
   2126                                     ir_emitter_context_);
   2127   TF_CHECK_OK(body->Accept(&ir_emitter_body));
   2128 
   2129   return absl::make_unique<WhileThunk>(
   2130       GetAllocationSlice(*condition->root_instruction()),  // cond result
   2131       ir_emitter_condition.ConsumeThunkSequence(),
   2132       ir_emitter_body.ConsumeThunkSequence(), hlo);
   2133 }
   2134 
   2135 std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
   2136     const HloInstruction* hlo, const int64 loop_limit) {
   2137   // Check that all while-related buffers share an allocation.
   2138   TF_CHECK_OK(CheckWhileBuffersShareAllocation(
   2139       hlo, ir_emitter_context_->buffer_assignment()));
   2140 
    2141   // Generate thunk sequence for while 'body' (will be used as a For loop body).
   2142   HloComputation* body = hlo->while_body();
   2143   IrEmitterUnnested ir_emitter_body(hlo_module_config_, body,
   2144                                     ir_emitter_context_);
   2145   TF_CHECK_OK(body->Accept(&ir_emitter_body));
   2146 
   2147   return absl::make_unique<ForThunk>(
   2148       loop_limit, ir_emitter_body.ConsumeThunkSequence(), hlo);
   2149 }
   2150 
   2151 std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
   2152     const HloInstruction* hlo) {
    2153   // Check that the buffers used in the conditional are shared with the
    2154   // operands and result appropriately.
   2155   TF_CHECK_OK(CheckConditionalBuffersShareAllocation(
   2156       hlo, ir_emitter_context_->buffer_assignment()));
   2157 
   2158   std::vector<BufferAllocation::Slice> branch_operands;
   2159   std::vector<ThunkSequence> branch_thunks;
   2160   for (int j = 0; j < hlo->branch_count(); ++j) {
   2161     branch_operands.emplace_back(GetAllocationSlice(*hlo->operand(j + 1)));
   2162     HloComputation* branch_computation = hlo->branch_computation(j);
   2163     IrEmitterUnnested ir_emitter(hlo_module_config_, branch_computation,
   2164                                  ir_emitter_context_);
   2165     TF_CHECK_OK(branch_computation->Accept(&ir_emitter));
   2166     branch_thunks.push_back(std::move(*ir_emitter.ConsumeThunkSequence()));
   2167   }
   2168 
   2169   return absl::make_unique<ConditionalThunk>(
   2170       GetAllocationSlice(*hlo->operand(0)), branch_operands,
   2171       std::move(branch_thunks), hlo);
   2172 }
   2173 
   2174 Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
   2175     const HloInstruction& hlo,
   2176     const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk) {
   2177   int unroll_factor = thunk->unroll_factor();
   2178   VLOG(3) << bindings_.ToString();
   2179 
   2180   const Shape& element_shape = hlo.IsMultiOutputFusion()
   2181                                    ? ShapeUtil::GetSubshape(hlo.shape(), {0})
   2182                                    : hlo.shape();
   2183   VLOG(3) << "EmitTargetElementLoopInThunk "
   2184           << ShapeUtil::HumanStringWithLayout(hlo.shape())
   2185           << " for unroll_factor " << unroll_factor;
   2186   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
   2187       element_shape, ir_emitter_context_->device_description(), unroll_factor);
   2188   UpdateLaunchDimensions(launch_dimensions, thunk,
   2189                          ir_emitter_context_->llvm_module());
   2190   if (!hlo.IsMultiOutputFusion()) {
   2191     return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo),
   2192                                launch_dimensions, &b_, unroll_factor)
   2193         .EmitLoop(
   2194             IrName(&hlo),
   2195             GetIndexTypeForKernel(&hlo, launch_dimensions.launch_bound(), &b_));
   2196   }
   2197 
   2198   // Emit the tuple pointers in one thread.  We could do this at any point in
   2199   // the kernel, but we do it at the beginning in the hopes of reducing register
   2200   // pressure, since we touch threadIdx.x and blockIdx.x at the beginning of the
   2201   // kernel *anyway*.
   2202   std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(hlo);
   2203   KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] {
   2204     llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_);
   2205   });
   2206 
   2207   // For multioutput fusion, we need to emit each operand and the root.
   2208   TF_RETURN_IF_ERROR(
   2209       ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions,
   2210                           &b_, unroll_factor)
   2211           .EmitLoop(IrName(&hlo),
   2212                     GetIndexTypeForKernel(
   2213                         &hlo, launch_dimensions.launch_bound(), &b_)));
   2214 
   2215   b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
   2216   return Status::OK();
   2217 }
   2218 
   2219 namespace {
   2220 
   2221 // Returns true if the fusion contains any instruction that is likely to be
   2222 // translated to complex LLVM IR, such as loops, and thus prevents vectorization.
   2223 bool MayPreventVectorization(const HloInstruction& fusion_hlo) {
   2224   CHECK_EQ(fusion_hlo.opcode(), HloOpcode::kFusion);
   2225   return absl::c_any_of(
   2226       fusion_hlo.fused_instructions_computation()->instructions(),
   2227       [&](const HloInstruction* instr) {
   2228         switch (instr->opcode()) {
   2229           case HloOpcode::kReduce:
   2230           case HloOpcode::kReduceWindow:
   2231           case HloOpcode::kSort:
   2232           case HloOpcode::kDot:
   2233             return true;
   2234           default:
   2235             return false;
   2236         }
   2237       });
   2238 }
   2239 
   2240 }  // namespace
   2241 
   2242 Status IrEmitterUnnested::EmitTargetElementLoop(
   2243     const HloInstruction& hlo,
   2244     const llvm_ir::ElementGenerator& element_generator) {
   2245   int unroll_factor = 1;
   2246   // Unfused elementwise operations are usually memory bound, so unroll them.
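          // For example, a kLoop fusion of purely elementwise ops is unrolled below,
          // while a fusion containing a kReduce or kDot keeps unroll_factor == 1.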
   2247   if (hlo.IsElementwise() ||
   2248       (hlo.opcode() == HloOpcode::kFusion && !MayPreventVectorization(hlo))) {
   2249     unroll_factor = ComputeMaxUnrollFactor(&hlo);
   2250   }
   2251 
   2252   std::unique_ptr<KernelThunk> kernel_thunk = BuildKernelThunk(
   2253       &hlo, /*implements_whole_instruction=*/true, unroll_factor);
   2254   Status emit_status =
   2255       EmitTargetElementLoopInThunk(hlo, element_generator, kernel_thunk.get());
   2256   thunk_sequence_->emplace_back(std::move(kernel_thunk));
   2257 
   2258   return emit_status;
   2259 }
   2260 
   2261 std::vector<IrArray> IrEmitterUnnested::ConstructIrArrayForInputs(
   2262     const HloInstruction& hlo) {
   2263   std::vector<IrArray> param_arrays;
   2264   param_arrays.reserve(hlo.operands().size());
   2265   for (const HloInstruction* param : hlo.operands()) {
   2266     param_arrays.push_back(GetIrArray(*param, hlo));
   2267   }
   2268   return param_arrays;
   2269 }
   2270 
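        // For each parameter of `hlo` with a non-null entry in `param_buffers`, builds
        // the reduced parameter shape (the 0-2-1 permutation of `reduced_output_dims`)
        // and an IrArray view of the parameter cast to that shape; parameters with a
        // null entry get empty placeholders. Returns the number of parameters of `hlo`.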
   2271 int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape(
   2272     const HloInstruction& hlo, const std::vector<IrArray>& param_arrays,
   2273     const std::vector<llvm::Value*>& param_buffers,
   2274     absl::Span<const int64> reduced_output_dims,
   2275     std::vector<Shape>* param_reduced_shapes,
   2276     std::vector<IrArray>* param_in_reduced_shape_arrays) {
   2277   int64 num_params = hlo.operands().size();
   2278   param_in_reduced_shape_arrays->reserve(num_params);
   2279   param_reduced_shapes->reserve(num_params);
   2280   for (int64 id = 0; id < num_params; ++id) {
   2281     if (param_buffers[id] == nullptr) {
   2282       param_reduced_shapes->push_back(Shape());
   2283       param_in_reduced_shape_arrays->push_back(IrArray());
   2284       continue;
   2285     }
   2286     const HloInstruction* param = hlo.operand(id);
   2287     param_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout(
   2288         param->shape().element_type(),
   2289         Permute({0, 2, 1}, reduced_output_dims)));
   2290     param_in_reduced_shape_arrays->push_back(
   2291         param_arrays[id].CastToShape((*param_reduced_shapes)[id], &b_));
   2292   }
   2293   return num_params;
   2294 }
   2295 
   2296 namespace {
   2297 
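        // Computes the starting x offset and x step for a thread, given its x
        // coordinate within the tile. Illustrative example (values assumed, not taken
        // from any particular kernel): with tile_size_x = 64 and num_threads_x = 32,
        // a dilated mapping gives thread x the elements x and x + 32 (start = x,
        // step = 32), while a non-dilated mapping gives it the contiguous pair
        // 2 * x and 2 * x + 1 (start = 2 * x, step = 1).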
   2298 std::tuple<llvm::Value*, int64> GetStartOffsetAndStepForX(
   2299     int64 tile_size_x, int64 num_threads_x,
   2300     const KernelMappingScheme* mapping_scheme, llvm::IRBuilder<>* builder,
   2301     llvm::Value* x, llvm::Type* index_ty) {
   2302   llvm::Value* start_offset_x;
   2303   int64 step_x;
   2304   if (mapping_scheme->DilatedX()) {
   2305     start_offset_x = x;
   2306     step_x = num_threads_x;
   2307   } else {
   2308     start_offset_x = builder->CreateMul(
   2309         x, llvm::ConstantInt::get(index_ty, tile_size_x / num_threads_x));
   2310     step_x = 1;
   2311   }
   2312   return std::make_tuple(start_offset_x, step_x);
   2313 }
   2314 
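        // Emits a loop nest in which each thread processes a full tile: the emitted
        // outer loop walks the y dimension in strides of num_threads_y, and the inner
        // loop over x is unrolled into tile_size_x / num_threads_x element emissions.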
   2315 void EmitFullElementalTile(const KernelMappingScheme* mapping_scheme,
   2316                            const IrArray::Index& tile_origin_index,
   2317                            const string& loop_name, KernelSupportLibrary* ksl,
   2318                            llvm::IRBuilder<>* builder, llvm::Value* y,
   2319                            llvm::Value* x, llvm::Type* index_ty,
   2320                            const EmitElementFunction& emit_elem_function) {
   2321   int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX();
   2322   int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY();
   2323   int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX();
   2324   int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY();
   2325 
   2326   llvm::Value* start_offset_x;
   2327   int64 step_x;
   2328   std::tie(start_offset_x, step_x) = GetStartOffsetAndStepForX(
   2329       tile_size_x, num_threads_x, mapping_scheme, builder, x, index_ty);
   2330   IrArray::Index source_idx =
   2331       tile_origin_index.AddOffsetToDim(y, KernelMappingScheme::DimY, builder)
   2332           .AddOffsetToDim(start_offset_x, KernelMappingScheme::DimX, builder);
   2333   ksl->For(loop_name + "_y", /*start=*/llvm::ConstantInt::get(index_ty, 0),
   2334            /*end=*/llvm::ConstantInt::get(index_ty, tile_size_y),
   2335            /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y),
   2336            [&](llvm::Value* y_indvar) {
   2337              IrArray::Index source_idx_y = source_idx.AddOffsetToDim(
   2338                  y_indvar, KernelMappingScheme::DimY, builder);
   2339              llvm::Value* y_loc = builder->CreateAdd(y_indvar, y);
   2340 
   2341              for (int64 j = 0; j < tile_size_x / num_threads_x; j++) {
   2342                IrArray::Index source_idx_y_x = source_idx_y.AddOffsetToDim(
   2343                    llvm::ConstantInt::get(index_ty, j * step_x),
   2344                    KernelMappingScheme::DimX, builder);
   2345                llvm::Value* x_loc = builder->CreateAdd(
   2346                    llvm::ConstantInt::get(index_ty, j * step_x),
   2347                    start_offset_x);
   2348                emit_elem_function(source_idx_y_x, y_loc, x_loc, j);
   2349              }
   2350            });
   2351 }
   2352 
   2353 void EmitPartialElementalTile(const KernelMappingScheme* mapping_scheme,
   2354                               const IrArray::Index& tile_origin_index,
   2355                               const string& loop_name,
   2356                               KernelSupportLibrary* ksl,
   2357                               llvm::IRBuilder<>* builder, llvm::Value* y,
   2358                               llvm::Value* x, llvm::Value* tile_height,
   2359                               llvm::Value* tile_width, llvm::Type* index_ty,
   2360                               const EmitElementFunction& emit_elem_function) {
   2361   int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX();
   2362   int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY();
   2363   int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX();
   2364 
   2365   llvm::Value* start_offset_x;
   2366   int64 step_x;
   2367   std::tie(start_offset_x, step_x) = GetStartOffsetAndStepForX(
   2368       tile_size_x, num_threads_x, mapping_scheme, builder, x, index_ty);
   2369   IrArray::Index source_idx =
   2370       tile_origin_index.AddOffsetToDim(y, KernelMappingScheme::DimY, builder)
   2371           .AddOffsetToDim(start_offset_x, KernelMappingScheme::DimX, builder);
   2372   for (int64 j = 0; j < tile_size_x / num_threads_x; j++) {
   2373     IrArray::Index source_idx_x =
   2374         source_idx.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j * step_x),
   2375                                   KernelMappingScheme::DimX, builder);
   2376     llvm::Value* x_loc = builder->CreateAdd(
   2377         llvm::ConstantInt::get(index_ty, j * step_x), start_offset_x);
   2378 
   2379     ksl->If(
   2380         loop_name + "_x_in_tile", builder->CreateICmpULT(x_loc, tile_width),
   2381         [&] {
   2382           // tile_height_bound =
   2383           //   ceil(tile_height / num_threads_y) * num_threads_y
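                  // Illustrative example (values assumed): tile_height = 13 and
                  // num_threads_y = 4 give ceiling_of_ratio = 4 and tile_height_bound
                  // = 16; each thread then makes 4 trips through the loop and the
                  // y_loc < tile_height check below discards out-of-bounds iterations.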
   2384           llvm::Value* ceiling_of_ratio = builder->CreateUDiv(
   2385               builder->CreateAdd(tile_height, llvm::ConstantInt::get(
   2386                                                   index_ty, num_threads_y - 1)),
   2387               llvm::ConstantInt::get(index_ty, num_threads_y));
   2388           llvm::Value* tile_height_bound = builder->CreateMul(
   2389               ceiling_of_ratio,
   2390               llvm::ConstantInt::get(index_ty, num_threads_y));
   2391           ksl->For(
   2392               loop_name, /*start=*/llvm::ConstantInt::get(index_ty, 0),
   2393               /*end=*/tile_height_bound,
   2394               /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y),
   2395               [&](llvm::Value* y_indvar) {
   2396                 llvm::Value* y_loc = builder->CreateAdd(y_indvar, y);
   2397                 ksl->If(loop_name + "_y_in_tile",
   2398                         builder->CreateICmpULT(y_loc, tile_height), [&] {
   2399                           emit_elem_function(
   2400                               source_idx_x.AddOffsetToDim(
   2401                                   y_indvar, KernelMappingScheme::DimY, builder),
   2402                               y_loc, x_loc, j);
   2403                         });
   2404               });
   2405         });
   2406   }
   2407 }
   2408 
   2409 // Emits code to process up to
   2410 // (tile_size_x/num_threads_x * tile_size_y/num_threads_y) elements in a tile,
   2411 // where `emit_elem_function` emits the code to process one element, `y` and `x`
   2412 // are the intra-tile coordinates for the first element to process, and
   2413 // `tile_origin_index` is the index for the origin of the tile. Information
   2414 // about tile_size_x/y and num_threads_x/y is stored in `mapping_scheme`. Emits
   2415 // bounds checks to ensure that each processed element is within the boundary
   2416 // defined by `tile_width` and `tile_height`.
   2417 void EmitTiledElementalCodeWithBoundsCheck(
   2418     const KernelMappingScheme* mapping_scheme,
   2419     const IrArray::Index& tile_origin_index, const string& loop_name,
   2420     KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y,
   2421     llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width,
   2422     const EmitElementFunction& emit_elem_function) {
   2423   int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX();
   2424   int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY();
   2425   llvm::Type* index_ty = tile_width->getType();
   2426 
   2427   ksl->If(
   2428       loop_name + "_full_tile",
   2429       builder->CreateAnd(
   2430           builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_x),
   2431                                 tile_width),
   2432           builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_y),
   2433                                 tile_height)),
   2434       [&] {
   2435         EmitFullElementalTile(mapping_scheme, tile_origin_index, loop_name, ksl,
   2436                               builder, y, x, index_ty, emit_elem_function);
   2437       },
   2438       [&] {
   2439         EmitPartialElementalTile(mapping_scheme, tile_origin_index, loop_name,
   2440                                  ksl, builder, y, x, tile_height, tile_width,
   2441                                  index_ty, emit_elem_function);
   2442       });
   2443 }
   2444 }  // namespace
   2445 
   2446 // Emits code to process a tensor element in a tile for the given kCopy HLO that
   2447 // performs a 0-2-1 transpose.
   2448 //
   2449 // index: The index for the first output element in the normalized tensor. The
   2450 //   normalized tensor is the resulting tensor after collapsing contiguous
   2451 //   dimensions that play the same role in the transpose.
   2452 // y_loc: The y coordinate within a tile.
   2453 // x_loc: The x coordinate within a tile.
   2454 // kernel_info: Other information to support the kernel code generation.
   2455 void IrEmitterUnnested::EmitTileElementForCopy(
   2456     HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
   2457     const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
   2458     llvm::Value* x_loc, int64 /*x_iter_num*/) {
   2459   llvm_ir::TiledParameterInfo* tiled_param_info =
   2460       kernel_info->GetTiledParameterInfo();
   2461   // TODO(jlebar): Add AA metadata to this load.
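          // Note that the tile is read as [0, x_loc, y_loc], with x and y swapped
          // relative to the order in which the tile was written to shared memory,
          // which performs the transpose within the tile.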
   2462   llvm::Instruction* load_from_shmem_buffer =
   2463       Load(GEP(tiled_param_info->GetBufferForParameter(0),
   2464                {b_.getInt64(0), x_loc, y_loc}),
   2465            "output_element");
   2466   llvm_ir::IrArray output_array = GetIrArray(*hlo, *hlo);
   2467   Shape output_reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
   2468       hlo->shape().element_type(),
   2469       kernel_info->GetKernelMappingScheme()->GetDimensionsInElements());
   2470   // When the output_reduced_shape is a 0-2-1 transpose of the input shape,
   2471   // the 0-2-1 transpose is achieved through EmitWriteArrayElement.
   2472   output_array.CastToShape(output_reduced_shape, &b_)
   2473       .EmitWriteArrayElement(index, load_from_shmem_buffer, &b_);
   2474 }
   2475 
   2476 // Emits code to process a tensor element in a tile for the given kLoop fusion
   2477 // HLO containing parameters that are 0-2-1 transpose of its outputs.
   2478 //
   2479 // index: The index for the first output element in the normalized tensor, that
   2480 //   is the resulting tensor after collapsing contiguous dimensions that play
   2481 //   the same role in the transpose.
   2482 // kernel_info: Other information to support the kernel code generation.
   2483 // y_loc: The y coordinate within a tile.
   2484 // x_loc: The x coordinate within a tile.
   2485 void IrEmitterUnnested::EmitTileElementForFusion(
   2486     HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
   2487     const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
   2488     llvm::Value* x_loc, int64 /*x_iter_num*/) {
   2489   llvm_ir::TiledParameterInfo* tiled_param_info =
   2490       kernel_info->GetTiledParameterInfo();
   2491   std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo);
   2492   GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
   2493                                      GetNestedComputer());
   2494   FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo),
   2495                                &elem_emitter);
   2496   tiled_param_info->set_y(y_loc);
   2497   tiled_param_info->set_x(x_loc);
   2498   fused_emitter.SetTiledParameterInfo(tiled_param_info);
   2499   TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter));
   2500   IrArray::Index untiled_index =
   2501       kernel_info->GetKernelMappingScheme()->GetUnnormalizedIndex(
   2502           index, output_arrays[0].GetShape());
   2503   const llvm_ir::ElementGenerator& output_generator =
   2504       fused_emitter.GetRootGenerator();
   2505   llvm::Value* output_value = output_generator(untiled_index).ValueOrDie();
   2506   if (hlo->IsMultiOutputFusion()) {
   2507     DCHECK(output_value->getType()->isStructTy());
   2508     DCHECK_EQ(output_value->getType()->getStructNumElements(),
   2509               output_arrays.size());
   2510     for (int64 i = 0; i < output_arrays.size(); ++i) {
   2511       output_arrays[i].EmitWriteArrayElement(
   2512           untiled_index, ExtractValue(output_value, i), &b_);
   2513     }
   2514   } else {
   2515     output_arrays[0].EmitWriteArrayElement(untiled_index, output_value, &b_);
   2516   }
   2517 }
   2518 
   2519 // Information to support the code generation for a tiled reduction kernel.
   2520 using AddressVector = InlinedVector<llvm::AllocaInst*, 1>;
   2521 class ReductionCodegenInfo : public IrEmitterUnnested::KernelCodegenInfo {
   2522  public:
   2523   explicit ReductionCodegenInfo(llvm_ir::KernelMappingScheme* mapping_scheme,
   2524                                 bool is_row_reduction)
   2525       : KernelCodegenInfo(mapping_scheme),
   2526         current_output_linear_index_address_(nullptr),
   2527         current_output_inbound_address_(nullptr),
   2528         is_row_reduction_(is_row_reduction) {}
   2529 
   2530   void SetCurrentOutputLinearIndexAddress(llvm::AllocaInst* a) {
   2531     current_output_linear_index_address_ = a;
   2532   }
   2533   // Returns the address of the memory that stores the linear index of the
   2534   // current output. Since we are processing a reduction to contiguous physical
   2535   // dimensions, this linear index is the linear index of the 1D output array.
   2536   llvm::AllocaInst* GetCurrentOutputLinearIndexAddress() const {
   2537     return current_output_linear_index_address_;
   2538   }
   2539 
   2540   void SetCurrentOutputInboundAddress(llvm::AllocaInst* a) {
   2541     current_output_inbound_address_ = a;
   2542   }
   2543 
   2544   llvm::AllocaInst* GetCurrentOutputInboundAddress() const {
   2545     return current_output_inbound_address_;
   2546   }
   2547 
   2548   AddressVector* GetMutablePartialResultAddresses() {
   2549     return &partial_result_addresses_;
   2550   }
   2551   absl::Span<llvm::AllocaInst* const> GetPartialResultAddresses() const {
   2552     return partial_result_addresses_;
   2553   }
   2554 
   2555   AddressVector* GetMutableReductionInputAddresses() {
   2556     return &reduction_input_addresses_;
   2557   }
   2558   absl::Span<llvm::AllocaInst* const> GetReductionInputAddresses() const {
   2559     return reduction_input_addresses_;
   2560   }
   2561 
   2562   InlinedVector<HloComputation*, 1>* GetMutableReducers() { return &reducers_; }
   2563   const InlinedVector<HloComputation*, 1>& GetReducers() const {
   2564     return reducers_;
   2565   }
   2566   int GetNumberOfReduces() const { return reducers_.size(); }
   2567 
   2568   InlinedVector<ShapeIndex, 1>* GetMutableReductionOutputShapeIndices() {
   2569     return &reduction_output_shape_indices_;
   2570   }
   2571   absl::Span<const ShapeIndex> GetReductionOutputShapeIndices() const {
   2572     return reduction_output_shape_indices_;
   2573   }
   2574 
   2575   bool IsRowReduction() const { return is_row_reduction_; }
   2576 
   2577   // Returns the dimension (DimX or DimY) that is being reduced.
   2578   int GetReducedDimensionEnum() const {
   2579     return IsRowReduction() ? llvm_ir::KernelMappingScheme::DimX
   2580                             : llvm_ir::KernelMappingScheme::DimY;
   2581   }
   2582 
   2583   // Returns the dimension (DimX or DimY) that is being kept.
   2584   int GetKeptDimensionEnum() const {
   2585     return IsRowReduction() ? llvm_ir::KernelMappingScheme::DimY
   2586                             : llvm_ir::KernelMappingScheme::DimX;
   2587   }
   2588 
   2589   int GetNumberOfPartialResults() const {
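          // Returns the number of partial results each thread keeps: one running
          // partial result for a row reduction, and one per x iteration of the tile
          // for a column reduction. Illustrative example (values assumed):
          // tile_size_x = 128 and num_threads_x = 32 give 4 partial results.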
   2590     if (IsRowReduction()) {
   2591       return 1;
   2592     }
   2593     int64 num_thread = mapping_scheme_->GetNumberOfThreadsForDimensionX();
   2594     int64 tile_size = mapping_scheme_->GetTileSizeForDimensionX();
   2595     CHECK_EQ(tile_size % num_thread, 0);
   2596     return tile_size / num_thread;
   2597   }
   2598 
   2599   int GetPartialResultIndex(int64 x_iter_num) const {
   2600     if (IsRowReduction()) {
   2601       return 0;
   2602     }
   2603     return x_iter_num;
   2604   }
   2605 
   2606  private:
   2607   AddressVector partial_result_addresses_;
   2608   AddressVector reduction_input_addresses_;
   2609   InlinedVector<HloComputation*, 1> reducers_;
   2610   InlinedVector<ShapeIndex, 1> reduction_output_shape_indices_;
   2611   llvm::AllocaInst* current_output_linear_index_address_;
   2612   llvm::AllocaInst* current_output_inbound_address_;
   2613   bool is_row_reduction_;
   2614 };
   2615 
   2616 namespace {
   2617 // Returns a group of instructions that generate the output for the kernel
   2618 // containing the given HLO instruction. The result may be an unnested kReduce
   2619 // HLO, a nested kReduce HLO of a kInput fusion, or the operands of the tuple
   2620 // for a multiple output fusion.
   2621 absl::Span<HloInstruction* const> GetOutputInstructions(
   2622     HloInstruction* const* reduce_or_tuple_pointer) {
   2623   HloOpcode opcode = (*reduce_or_tuple_pointer)->opcode();
   2624   CHECK(opcode == HloOpcode::kReduce || opcode == HloOpcode::kTuple);
   2625   return opcode == HloOpcode::kTuple
   2626              ? (*reduce_or_tuple_pointer)->operands()
   2627              : absl::Span<HloInstruction* const>(reduce_or_tuple_pointer, 1);
   2628 }
   2629 
   2630 const HloInstruction* GetFirstReduceInstruction(
   2631     absl::Span<HloInstruction* const> instructions) {
   2632   auto first_reduce_iter =
   2633       absl::c_find_if(instructions, [](const HloInstruction* inst) {
   2634         return inst->opcode() == HloOpcode::kReduce;
   2635       });
   2636   CHECK_NE(first_reduce_iter, instructions.end());
   2637   return *first_reduce_iter;
   2638 }
   2639 
   2640 }  // namespace
   2641 
   2642 void IrEmitterUnnested::EmitPrologueForOneReduction(
   2643     HloInstruction* unnested_hlo, HloInstruction* reduce_inst, int reduce_idx,
   2644     KernelCodegenInfo* kernel_info, GpuElementalIrEmitter* elemental_emitter,
   2645     ShapeIndex output_shape_index) {
   2646   ReductionCodegenInfo* reduction_info =
   2647       static_cast<ReductionCodegenInfo*>(kernel_info);
   2648 
   2649   InlinedVector<HloComputation*, 1>* reducers =
   2650       reduction_info->GetMutableReducers();
   2651   CHECK(IsReductionToVector(*reduce_inst));
   2652   reducers->push_back(reduce_inst->to_apply());
   2653 
   2654   InlinedVector<ShapeIndex, 1>* reduction_output_shape_indices =
   2655       reduction_info->GetMutableReductionOutputShapeIndices();
   2656   reduction_output_shape_indices->push_back(std::move(output_shape_index));
   2657 
   2658   AddressVector* reduction_input_addresses =
   2659       reduction_info->GetMutableReductionInputAddresses();
   2660   llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType(
   2661       reduce_inst->shape().element_type(), ir_emitter_context_->llvm_module());
   2662   llvm::AllocaInst* reduction_input_address = Alloca(element_type);
   2663   reduction_input_addresses->push_back(reduction_input_address);
   2664 
   2665   int num_partial_results = reduction_info->GetNumberOfPartialResults();
   2666   AddressVector* partial_result_addresses =
   2667       reduction_info->GetMutablePartialResultAddresses();
   2668   llvm::AllocaInst* partial_result_address =
   2669       Alloca(element_type, /*ArraySize=*/b_.getInt32(num_partial_results),
   2670              "partial_reduction_result." + llvm::Twine(reduce_idx));
   2671   partial_result_addresses->push_back(partial_result_address);
   2672 
   2673   // Initialize the partial result with the initial value of the reduction.
   2674   llvm::Value* init_ir_value;
   2675   if (unnested_hlo->opcode() == HloOpcode::kFusion) {
   2676     HloInstruction* init_value_operand = reduce_inst->mutable_operand(1);
   2677     FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo),
   2678                                  elemental_emitter);
   2679 
   2680     TF_CHECK_OK(init_value_operand->Accept(&fused_emitter));
   2681     init_ir_value =
   2682         fused_emitter
   2683             .GetGenerator(init_value_operand)(IrArray::Index(b_.getInt32Ty()))
   2684             .ValueOrDie();
   2685   } else {
   2686     const HloInstruction* init_value = unnested_hlo->operand(1);
   2687     init_ir_value =
   2688         GetIrArray(*init_value, *unnested_hlo)
   2689             .EmitReadArrayElement(IrArray::Index(b_.getInt32Ty()), &b_);
   2690   }
   2691 
   2692   for (int i = 0; i < num_partial_results; ++i) {
   2693     Store(init_ir_value, InBoundsGEP(partial_result_address, {b_.getInt32(i)}));
   2694   }
   2695 }
   2696 
   2697 void IrEmitterUnnested::EmitPrologueForReduction(
   2698     HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info) {
   2699   VLOG(10) << "Emit prologue for reduction " << unnested_hlo->ToString();
   2700   // Find the unnested kReduce or the tuple that contains a list of kReduce.
   2701   HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion
   2702                                         ? unnested_hlo->fused_expression_root()
   2703                                         : unnested_hlo;
   2704   absl::Span<HloInstruction* const> output_instructions =
   2705       GetOutputInstructions(&reduce_or_tuple);
   2706   ReductionCodegenInfo* reduction_info =
   2707       static_cast<ReductionCodegenInfo*>(kernel_info);
   2708   GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
   2709                                           ir_emitter_context_->llvm_module(),
   2710                                           &b_, GetNestedComputer());
   2711   const HloInstruction* first_reduce = nullptr;
   2712   for (int i = 0, e = output_instructions.size(); i != e; ++i) {
   2713     if (output_instructions[i]->opcode() != HloOpcode::kReduce) {
   2714       continue;
   2715     }
   2716     HloInstruction* reduce_inst = output_instructions[i];
   2717     if (first_reduce == nullptr) {
   2718       first_reduce = reduce_inst;
   2719     } else {
   2720       CHECK(first_reduce->dimensions() == reduce_inst->dimensions());
   2721     }
   2722     ShapeIndex output_shape_index;
   2723     if (reduce_or_tuple->opcode() == HloOpcode::kTuple) {
   2724       output_shape_index = {i};
   2725     }
   2726 
   2727     EmitPrologueForOneReduction(unnested_hlo, reduce_inst, i, kernel_info,
   2728                                 &elemental_emitter,
   2729                                 std::move(output_shape_index));
   2730   }
   2731 
   2732   int num_partial_results = reduction_info->GetNumberOfPartialResults();
   2733 
   2734   // Allocate stack storage to store the linear indices for the current output,
   2735   // and record the address of the storage.
   2736   reduction_info->SetCurrentOutputLinearIndexAddress(
   2737       Alloca(reduction_info->GetIndexType(),
   2738              /*ArraySize=*/b_.getInt32(num_partial_results),
   2739              "current_output_linear_index_address"));
   2740 
   2741   if (!reduction_info->IsRowReduction()) {
   2742     llvm::Type* bool_ty = b_.getInt1Ty();
   2743     llvm::AllocaInst* output_inbound_addr = Alloca(bool_ty);
   2744     Store(llvm::ConstantInt::get(bool_ty, 0), output_inbound_addr);
   2745     reduction_info->SetCurrentOutputInboundAddress(output_inbound_addr);
   2746   }
   2747 }
   2748 
   2749 void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForAllReduces(
   2750     absl::Span<HloComputation* const> reducers,
   2751     absl::Span<llvm::AllocaInst* const> partial_result_addresses) {
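          // Tree-reduce the per-lane partial results within each 32-lane warp: after
          // the shuffle-down passes with distances 16, 8, 4, 2 and 1, lane 0 of the
          // warp holds the combined result for all 32 lanes.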
   2752   for (int distance = 16; distance >= 1; distance /= 2) {
   2753     for (int i = 0; i != reducers.size(); ++i) {
   2754       llvm::Type* element_type =
   2755           partial_result_addresses[i]->getType()->getElementType();
   2756       int bit_width = llvm_ir::GetSizeInBits(element_type);
   2757       llvm::Value* result_from_other_lane = Alloca(
   2758           element_type, nullptr, "result_from_other_lane" + llvm::Twine(i));
   2759       // Bitcast cannot be applied to aggregate types (even packed ones), so
   2760       // we bitcast addresses of load/store to intN* of the same bit-width.
   2761       llvm::Type* shuffled_value_type =
   2762           element_type->isStructTy() ? b_.getIntNTy(bit_width) : element_type;
   2763       auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) {
   2764         return BitCast(ptr, shuffled_value_type->getPointerTo());
   2765       };
   2766       llvm::Value* partial_result =
   2767           Load(convert_pointer_for_shuffle(partial_result_addresses[i]),
   2768                "partial_reduction_result");
   2769       Store(EmitFullWarpShuffleDown(partial_result, b_.getInt32(distance), &b_),
   2770             convert_pointer_for_shuffle(result_from_other_lane));
   2771       TF_CHECK_OK(EmitCallToNestedComputation(
   2772           *reducers[i], {partial_result_addresses[i], result_from_other_lane},
   2773           partial_result_addresses[i]));
   2774     }
   2775   }
   2776 }
   2777 
   2778 void IrEmitterUnnested::EmitEpilogueForReduction(
   2779     HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info) {
   2780   ReductionCodegenInfo* reduction_info =
   2781       static_cast<ReductionCodegenInfo*>(kernel_info);
   2782   int num_reduces = reduction_info->GetNumberOfReduces();
   2783   absl::Span<llvm::AllocaInst* const> partial_result_addresses =
   2784       reduction_info->GetPartialResultAddresses();
   2785   const InlinedVector<HloComputation*, 1>& reducers =
   2786       reduction_info->GetReducers();
   2787   absl::Span<const ShapeIndex> reduction_output_shape_indices =
   2788       reduction_info->GetReductionOutputShapeIndices();
   2789 
   2790   if (reduction_info->IsRowReduction()) {
   2791     EmitFullWarpShuffleDownLoopForAllReduces(reducers,
   2792                                              partial_result_addresses);
   2793     llvm::Value* lane_id = reduction_info->GetLaneId();
   2794     llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
   2795         ICmpEQ(lane_id, llvm::ConstantInt::get(lane_id->getType(), 0)),
   2796         "lane_id_is_zero", &b_);
   2797     llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_);
   2798   } else {
   2799     llvm::Value* output_inbound_addr =
   2800         reduction_info->GetCurrentOutputInboundAddress();
   2801     llvm::Value* output_inbound = Load(output_inbound_addr);
   2802     llvm_ir::LlvmIfData if_output_inbound_data = llvm_ir::EmitIfThenElse(
   2803         ICmpEQ(output_inbound,
   2804                llvm::ConstantInt::get(output_inbound->getType(), 1)),
   2805         "output_inbound", &b_);
   2806     llvm_ir::SetToFirstInsertPoint(if_output_inbound_data.true_block, &b_);
   2807   }
   2808 
   2809   int num_partial_results = reduction_info->GetNumberOfPartialResults();
   2810 
   2811   // Emit an atomic operation that accumulates the partial reduction to the
   2812   // output element. For row reduction, this is only for lane 0 due to the
   2813   // if-statement emitted above.
   2814   for (int i = 0; i != num_reduces; ++i) {
   2815     for (int j = 0; j < num_partial_results; ++j) {
   2816       IrArray::Index element_index(
   2817           /*linear=*/Load(
   2818               InBoundsGEP(reduction_info->GetCurrentOutputLinearIndexAddress(),
   2819                           {b_.getInt32(j)}),
   2820               "output_linear_addr"),
   2821           ShapeUtil::GetSubshape(unnested_hlo->shape(),
   2822                                  reduction_output_shape_indices[i]),
   2823           &b_);
   2824       llvm::Value* output_address =
   2825           GetIrArray(*unnested_hlo, *unnested_hlo,
   2826                      reduction_output_shape_indices[i])
   2827               .EmitArrayElementAddress(element_index, &b_,
   2828                                        "output_element_address");
   2829       // Do not emit atomic operations if each element in the reduction result
   2830       // is computed by a single block, that is, the dimension being reduced
   2831       // spans only one block.
   2832       const llvm_ir::KernelMappingScheme* mapping_scheme =
   2833           reduction_info->GetKernelMappingScheme();
   2834       if (mapping_scheme->GetTileBlockSizeForDimension(
   2835               llvm_ir::KernelMappingScheme::DimZ) == 1 &&
   2836           mapping_scheme->GetTileBlockSizeForDimension(
   2837               reduction_info->GetReducedDimensionEnum()) == 1) {
   2838         TF_CHECK_OK(EmitCallToNestedComputation(
   2839             *reducers[i],
   2840             {output_address,
   2841              InBoundsGEP(partial_result_addresses[i], {b_.getInt32(j)})},
   2842             output_address));
   2843       } else {
   2844         TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
   2845             *reducers[i], output_address,
   2846             InBoundsGEP(partial_result_addresses[i], {b_.getInt32(j)})));
   2847       }
   2848     }
   2849   }
   2850 }
   2851 
   2852 void IrEmitterUnnested::EmitTileElementForReduction(
   2853     HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index,
   2854     const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
   2855     llvm::Value* x_loc, int64 x_iter_num) {
   2856   VLOG(10) << "Emit tile element for reduce " << unnested_hlo->ToString();
   2857   HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion
   2858                                         ? unnested_hlo->fused_expression_root()
   2859                                         : unnested_hlo;
   2860   llvm_ir::TiledParameterInfo* tiled_param_info =
   2861       kernel_info->GetTiledParameterInfo();
   2862   tiled_param_info->set_y(y_loc);
   2863   tiled_param_info->set_x(x_loc);
   2864 
   2865   // Record the linear address for the current reduction.
   2866   const ReductionCodegenInfo* reduction_info =
   2867       dynamic_cast<const ReductionCodegenInfo*>(kernel_info);
   2868   int partial_result_index = reduction_info->IsRowReduction() ? 0 : x_iter_num;
   2869 
   2870   Store(index[reduction_info->GetKeptDimensionEnum()],
   2871         InBoundsGEP(reduction_info->GetCurrentOutputLinearIndexAddress(),
   2872                     {b_.getInt32(partial_result_index)}));
   2873   if (!reduction_info->IsRowReduction()) {
   2874     llvm::Type* bool_ty = b_.getInt1Ty();
   2875     llvm::AllocaInst* output_inbound_addr =
   2876         reduction_info->GetCurrentOutputInboundAddress();
   2877     Store(llvm::ConstantInt::get(bool_ty, 1), output_inbound_addr);
   2878   }
   2879 
   2880   InlinedVector<llvm_ir::ElementGenerator, 1> input_gens;
   2881   std::vector<std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
   2882       extra_output_gens;
   2883   GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
   2884                                      GetNestedComputer());
   2885   FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo),
   2886                                &elem_emitter);
   2887   absl::Span<HloInstruction* const> output_instructions =
   2888       GetOutputInstructions(&reduce_or_tuple);
   2889   // Construct the ElementGenerator for each reduction and extra output in
   2890   // the group of output instructions.
   2891   if (unnested_hlo->opcode() == HloOpcode::kFusion) {
   2892     fused_emitter.SetTiledParameterInfo(tiled_param_info);
   2893     TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter));
   2894 
   2895     for (int i = 0, e = output_instructions.size(); i != e; ++i) {
   2896       const HloInstruction* inst = output_instructions[i];
   2897       ShapeIndex output_shape_index;
   2898       if (reduce_or_tuple->opcode() == HloOpcode::kTuple) {
   2899         output_shape_index = {i};
   2900       }
   2901       if (inst->opcode() == HloOpcode::kReduce) {
   2902         input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0)));
   2903       } else {
   2904         extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst),
   2905                                        std::move(output_shape_index));
   2906       }
   2907     }
   2908   } else {
   2909     input_gens.push_back([&](const IrArray::Index& index) {
   2910       return GetIrArray(*unnested_hlo->operand(0), *unnested_hlo)
   2911           .EmitReadArrayElement(index, &b_);
   2912     });
   2913   }
   2914 
   2915   IrArray::Index input_index =
   2916       reduction_info->GetKernelMappingScheme()->GetUnnormalizedIndex(
   2917           index,
   2918           GetFirstReduceInstruction(output_instructions)->operand(0)->shape());
   2919   absl::Span<llvm::AllocaInst* const> partial_reduction_result_addresses =
   2920       reduction_info->GetPartialResultAddresses();
   2921   absl::Span<llvm::AllocaInst* const> reduction_input_addresses =
   2922       reduction_info->GetReductionInputAddresses();
   2923   const InlinedVector<HloComputation*, 1>& reducers =
   2924       reduction_info->GetReducers();
   2925 
   2926   // Emit code to generate the input and perform the reduction computation for
   2927   // each reduction instruction.
   2928   for (int i = 0; i != reducers.size(); ++i) {
   2929     llvm::Value* const input_ir_value = input_gens[i](input_index).ValueOrDie();
   2930     Store(input_ir_value, reduction_input_addresses[i]);
   2931     llvm::Value* partial_result_address =
   2932         InBoundsGEP(partial_reduction_result_addresses[i],
   2933                     {b_.getInt32(partial_result_index)});
   2934     TF_CHECK_OK(EmitCallToNestedComputation(
   2935         *reducers[i], {partial_result_address, reduction_input_addresses[i]},
   2936         partial_result_address));
   2937   }
   2938 
   2939   // Emit code to generate the output for the non-reduction instructions in the
   2940   // fusion, if any.
   2941   TF_CHECK_OK(
   2942       EmitExtraOutputsForReduce(unnested_hlo, input_index, extra_output_gens));
   2943 }
   2944 
   2945 // Emits the block of tiles for the hlo instruction using the given tiling scheme.
   2946 void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile,
   2947                                   KernelCodegenInfo* kernel_info,
   2948                                   KernelSupportLibrary* ksl,
   2949                                   llvm::Type* index_ty) {
   2950   KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme();
   2951   absl::Span<const int64> dims_in_tile = mapping_scheme->GetDimensionsInTiles();
   2952   absl::Span<const int64> dims_in_block =
   2953       mapping_scheme->GetDimensionsInBlocks();
   2954   absl::Span<const int64> block_sizes = mapping_scheme->GetBlockSizes();
   2955   auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
   2956     return llvm::ConstantInt::get(index_ty, c);
   2957   };
   2958 
   2959   // Emit all the tiles for a given dimension in a tile block.
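          // Illustrative example (values assumed): with 10 tiles in a dimension and a
          // block size of 4, the blocks cover 4, 4 and 2 tiles respectively; the
          // Select below picks last_block_size_for_dim (here 2) for the last block.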
   2960   auto emit_tiles_for_block_dim =
   2961       [&](const string& loop_name, const IrArray::Index& starting_tile,
   2962           int dim_id,
   2963           const std::function<void(const IrArray::Index& tile_index)>
   2964               emit_next_block_dim) {
   2965         if (block_sizes[dim_id] == 1) {
   2966           emit_next_block_dim(starting_tile);
   2967         } else {
   2968           llvm::Value* starting_tile_index_for_dim = starting_tile[dim_id];
   2969           llvm::Value* block_size_for_dim =
   2970               index_typed_constant(block_sizes[dim_id]);
   2971           llvm::Value* block_id_for_dim =
   2972               b_.CreateUDiv(starting_tile_index_for_dim, block_size_for_dim);
   2973           llvm::Value* last_block_for_dim =
   2974               index_typed_constant(dims_in_block[dim_id] - 1);
   2975           llvm::Value* last_block_size_for_dim = index_typed_constant(
   2976               dims_in_tile[dim_id] -
   2977               (dims_in_block[dim_id] - 1) * block_sizes[dim_id]);
   2978           llvm::Value* num_tiles_in_block =
   2979               Select(ICmpEQ(last_block_for_dim, block_id_for_dim),
   2980                      last_block_size_for_dim, block_size_for_dim);
   2981           ksl->For(loop_name,
   2982                    /*start=*/index_typed_constant(0),
   2983                    /*end=*/num_tiles_in_block,
   2984                    /*step=*/1, [&](llvm::Value* block_dim_induction_var) {
   2985                      IrArray::Index tile_index = starting_tile.AddOffsetToDim(
   2986                          block_dim_induction_var, dim_id, &b_);
   2987                      emit_next_block_dim(tile_index);
   2988                    });
   2989         }
   2990       };
   2991 
   2992   absl::Span<const int64> reduced_dims =
   2993       mapping_scheme->GetDimensionsInElements();
   2994   const bool block_contains_multi_tiles =
   2995       mapping_scheme->GetNumberOfTilesInOneBlock() > 1;
   2996 
   2997   // Emit the tile with a given tile_index, by calculating the tight bounds for
   2998   // each dimension of the tile and then calling emit_one_tile.
   2999   auto emit_one_tile_for_tile_index = [&](const IrArray::Index& tile_index) {
   3000     std::vector<llvm::Value*> output_tile_bounds(3);
   3001     for (int i = KernelMappingScheme::DimY; i < KernelMappingScheme::DimTot;
   3002          ++i) {
   3003       int64 tile_size_for_dim = mapping_scheme->GetTileSizeForDimension(i);
   3004       // Only the last row or column may not have full size.
   3005       llvm::Value* is_last_row =
   3006           ICmpEQ(tile_index[i], index_typed_constant(dims_in_tile[i] - 1));
   3007       int64 partial_row_size =
   3008           reduced_dims[i] - (dims_in_tile[i] - 1) * tile_size_for_dim;
   3009       output_tile_bounds[i] =
   3010           Select(is_last_row, index_typed_constant(partial_row_size),
   3011                  index_typed_constant(tile_size_for_dim), "tile_bound");
   3012     }
   3013 
   3014     IrArray::Index tile_origin =
   3015         mapping_scheme->GetElementIndexForTileOrigin(tile_index);
   3016     emit_one_tile(tile_origin, output_tile_bounds, block_contains_multi_tiles);
   3017   };
   3018 
   3019   const IrArray::Index starting_block =
   3020       mapping_scheme->EmitBlockIndex(index_ty);
   3021   const IrArray::Index starting_tile_for_dim_z =
   3022       mapping_scheme->GetTileIndexForBlockOrigin(starting_block);
   3023 
   3024   // Emit the three dimensional block of tiles.
   3025   emit_tiles_for_block_dim(
   3026       "block_dim_z", starting_tile_for_dim_z, KernelMappingScheme::DimZ,
   3027       [&](const IrArray::Index& starting_tile_for_dim_y) {
   3028         emit_tiles_for_block_dim(
   3029             "block_dim_y", starting_tile_for_dim_y, KernelMappingScheme::DimY,
   3030             [&](const IrArray::Index& starting_tile_for_dim_x) {
   3031               emit_tiles_for_block_dim("block_dim_x", starting_tile_for_dim_x,
   3032                                        KernelMappingScheme::DimX,
   3033                                        emit_one_tile_for_tile_index);
   3034             });
   3035       });
   3036 }
   3037 
   3038 // Emits a kernel for the hlo instruction using the given kernel mapping scheme.
   3039 //
   3040 // unnested_hlo: The unnested hlo instruction for which the kernel is generated.
   3041 //   Currently, these hlo instructions are supported: kLoop fusion, kCopy.
   3042 // tiled_param_ids: The IDs for the parameters that are 0-2-1 transpose of
   3043 //   other tensors with the same dimensions and are safe to be transposed via
   3044 //   the shared memory transpose implementation.
   3045 // mapping_scheme: The tiling scheme to use.
   3046 // kernel_generator: Contains function objects for code generation, such as
   3047 //   element generator, block prologue and epilogue generators.
   3048 // kernel_info: Represents other information to support the code generation
   3049 //   of the tiled kernel for the hlo.
   3050 LaunchDimensions IrEmitterUnnested::EmitKernel(
   3051     HloInstruction* unnested_hlo, absl::Span<const int64> tiled_param_ids,
   3052     const KernelCodeGenerator& kernel_generator,
   3053     KernelCodegenInfo* kernel_info) {
   3054   KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme();
   3055 
   3056   std::vector<IrArray> param_arrays = ConstructIrArrayForInputs(*unnested_hlo);
   3057   int64 num_params = param_arrays.size();
   3058   // Allocate shared memory buffers to store the tiled inputs.
   3059   std::vector<llvm::Value*> param_shmem_buffers(num_params, nullptr);
   3060   for (int64 id : tiled_param_ids) {
   3061     const HloInstruction* param = unnested_hlo->operand(id);
   3062     param_shmem_buffers[id] =
   3063         mapping_scheme->GetSharedMemoryBufferForElementType(
   3064             llvm_ir::PrimitiveTypeToIrType(param->shape().element_type(),
   3065                                            module_),
   3066             IrName(unnested_hlo, StrCat("tile", id)));
   3067     VLOG(3) << "Added shmem buffer for parameter " << id << ": "
   3068             << llvm_ir::DumpToString(*param_shmem_buffers[id]);
   3069   }
   3070 
   3071   const ReductionCodegenInfo* reduction_info =
   3072       dynamic_cast<const ReductionCodegenInfo*>(kernel_info);
   3073   bool is_column_reduction =
   3074       (reduction_info && !reduction_info->IsRowReduction());
   3075 
   3076   LaunchDimensions launch_dimensions =
   3077       LaunchDimensions(mapping_scheme->GetNumberOfBlocks(),
   3078                        mapping_scheme->GetThreadsPerBlock());
   3079 
   3080   // TODO(b/110211620): Enable int32 index type for column reduction.
   3081   llvm::Type* index_ty =
   3082       is_column_reduction
   3083           ? b_.getInt64Ty()
   3084           : GetIndexTypeForKernel(unnested_hlo,
   3085                                   launch_dimensions.launch_bound(), &b_);
   3086 
   3087   auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
   3088     return llvm::ConstantInt::get(index_ty, c);
   3089   };
   3090 
   3091   // For multioutput fusion, one thread needs to output a tuple with pointers to
   3092   // all the individual outputs.  We could do this at any point in the kernel,
   3093   // but we do it at the beginning in the hopes of reducing register pressure,
   3094   // since we touch threadIdx.x and blockIdx.x at the beginning of the kernel
   3095   // *anyway*.
   3096   if (!reduction_info && unnested_hlo->IsMultiOutputFusion()) {
   3097     KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] {
   3098       llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo),
   3099                          ConstructIrArrayForOutputs(*unnested_hlo), &b_);
   3100     });
   3101   }
   3102 
   3103   // For each tiled parameter, cast its input IrArray to the corresponding
   3104   // reduced shape and keep the reduced shape live during IR emission.
   3105   std::vector<IrArray> param_in_reduced_shape_arrays;
   3106   std::vector<Shape> param_reduced_shapes;
   3107   absl::Span<const int64> reduced_dims =
   3108       mapping_scheme->GetDimensionsInElements();
   3109   int num_shapes = ConstructInputReducedShapeAndCastInputIrArrayToShape(
   3110       *unnested_hlo, param_arrays, param_shmem_buffers, reduced_dims,
   3111       &param_reduced_shapes, &param_in_reduced_shape_arrays);
   3112   DCHECK_EQ(num_shapes, num_params);
   3113 
   3114   // Calculate the starting element coordinate within a tile for the current
   3115   // thread, (y, x) from thread_id.
   3116   llvm::Value* x;
   3117   llvm::Value* y;
   3118   std::tie(y, x) = mapping_scheme->EmitThreadYXCoordinate(index_ty);
   3119 
   3120   kernel_info->SetLaneId(
   3121       mapping_scheme->GetNumberOfThreadsForDimensionX() == kWarpSize ? x
   3122                                                                      : nullptr);
   3123   kernel_info->SetIndexType(index_ty);
   3124 
   3125   KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
   3126   // Curry a few parameters to EmitTiledElementalCodeWithBoundsCheck.
   3127   auto emit_tiled_elemental_code_with_bounds_check =
   3128       [&](const IrArray::Index& index, const string& loop_name,
   3129           llvm::Value* tile_height, llvm::Value* tile_width,
   3130           const EmitElementFunction& emit_elem_function) {
   3131         EmitTiledElementalCodeWithBoundsCheck(mapping_scheme, index, loop_name,
   3132                                               &ksl, &b_, y, x, tile_height,
   3133                                               tile_width, emit_elem_function);
   3134       };
   3135 
   3136   auto emit_one_tile = [&](const IrArray::Index& output_tile_origin,
   3137                            absl::Span<llvm::Value* const> output_tile_bounds,
   3138                            bool block_contains_multi_tiles) {
   3139     // Calculate the input tile origin from the output tile origin.
   3140     const IrArray::Index input_tile_origin(
   3141         Permute({0, 2, 1}, output_tile_origin.multidim()));
   3142 
   3143     // If shared memory transpose is needed, wait for all threads to reach this
   3144     // point, lest we copy a value from tile to output before the other thread
   3145     // copies it from input to tile. This is `__syncthreads` in CUDA.
   3146     if (!tiled_param_ids.empty()) {
   3147       // Copy input parameter values to shared memory buffers:
   3148       // tile[y, x] = input[index]
   3149       // Note that tile_width and tile_height are flipped here because we are
   3150       // reading a transposed tile.
   3151       emit_tiled_elemental_code_with_bounds_check(
   3152           input_tile_origin, "input", output_tile_bounds[2],
   3153           output_tile_bounds[1],
   3154           [&](const IrArray::Index& index, llvm::Value* y_loc,
   3155               llvm::Value* x_loc, int64 /*x_iter_num*/) {
   3156             for (int64 id : tiled_param_ids) {
   3157               IrArray& input_in_logical_shape =
   3158                   param_in_reduced_shape_arrays[id];
   3159               llvm::Value* shmem_buffer = param_shmem_buffers[id];
   3160               // TODO(jlebar): Add AA metadata to this store.  Tile buffers are
   3161               // global variables, so LLVM can't infer much about it.
   3162               Store(input_in_logical_shape.EmitReadArrayElement(
   3163                         index, &b_, "input_element"),
   3164                     GEP(shmem_buffer, {index_typed_constant(0), y_loc, x_loc}));
   3165             }
   3166           });
   3167 
   3168       // Wait for all threads to reach this point using `__syncthreads` in CUDA.
   3169       llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_);
   3170     }
   3171 
   3172     llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x);
   3173     kernel_info->SetTiledParamInfo(&tiled_param_info);
   3174 
   3175     // Write to output[index] by emitting code like normal, except that values
   3176     // for the tiled parameters are read from the shmem buffers.
   3177     emit_tiled_elemental_code_with_bounds_check(
   3178         output_tile_origin, "output", output_tile_bounds[1],
   3179         output_tile_bounds[2],
   3180         [&](const IrArray::Index& index, llvm::Value* y_loc, llvm::Value* x_loc,
   3181             int64 x_iter_num) {
   3182           kernel_generator.GetTileElementGenerator()(
   3183               unnested_hlo, index, kernel_info, y_loc, x_loc, x_iter_num);
   3184         });
   3185 
   3186     // If a tile block contains multiple tiles and shared memory buffers are
   3187     // used, we need to wait for all threads to finish using the shared memory
   3188     // buffer for the current tile before we move on to process the next tile
   3189     // and overwrite the shared memory buffers.
   3190     if (block_contains_multi_tiles && !tiled_param_ids.empty()) {
   3191       llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_);
   3192     }
   3193   };
   3194 
   3195   const BlockPrologueGenerator& block_prologue_generator =
   3196       kernel_generator.GetBlockPrologueGenerator();
   3197   if (block_prologue_generator) {
   3198     block_prologue_generator(unnested_hlo, kernel_info);
   3199   }
   3200 
   3201   EmitBlock(std::move(emit_one_tile), kernel_info, &ksl, index_ty);
   3202 
   3203   const BlockEpilogueGenerator& block_epilogue_generator =
   3204       kernel_generator.GetBlockEpilogueGenerator();
   3205   if (block_epilogue_generator) {
   3206     block_epilogue_generator(unnested_hlo, kernel_info);
   3207   }
   3208 
   3209   return launch_dimensions;
   3210 }
   3211 
   3212 // Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose
   3213 // algorithm to improve the memory access patterns for the input parameters
   3214 // with a shape that is a 0-2-1 transpose of the output tensor shape. The caller
   3215 // is responsible for making sure that it is safe to apply the shared memory
   3216 // transpose on the input parameters.
   3217 //
   3219 // For the purpose of tiling, the output tensors have a logical shape of three
   3220 // components 0-2-1 while the relevant input parameters have a logical shape
   3221 // of three components 0-1-2 in the order major to minor. The x- and y-
   3222 // dimensions of the tensors are tiled in square tiles with an edge length
   3223 // `kTileSize`. Each thread block of `kTileSize` x `kNumRows` threads
   3224 // transposes one tile: each thread copies kTileSize/kNumRows elements from
   3225 // the input to a shared memory tile, then the otherwise "regular HLO kernel"
   3226 // reads from the shared memory instead of the original input.
   3227 //
   3228 // This is similar to the following CUDA algorithm in TensorFlow:
   3229 // https://goo.gl/MStRV6.
   3230 //
   3231 // `kTileSize` should usually be the same as the warp size. We currently choose 32 for
   3232 // `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`.
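        // With these choices each thread block has kTileSize * kNumRows = 32 * 4 = 128
        // threads, and each thread copies kTileSize / kNumRows = 8 elements of the
        // 32 x 32 tile into shared memory.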
   3233 //
   3234 // TODO(b/33320379): Here each block transposes 1 tile. It may be more
   3235 // efficient to launch fewer blocks so each transposes many tiles.
   3236 LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
   3237     HloInstruction* hlo, absl::Span<const int64> reduced_output_dims,
   3238     absl::Span<const int64> tiled_param_ids) {
   3239   constexpr int kNumRows = 4;
   3240   KernelMappingScheme mapping_scheme(
   3241       reduced_output_dims, /*tile_size_y=*/kWarpSize,
   3242       /*tile_size_x=*/kWarpSize, /*req_block_sizes=*/{1, 1, 1},
   3243       /*num_threads_y=*/kNumRows,
   3244       /*num_threads_x=*/kWarpSize, &b_);
   3245   TileElementGenerator element_generator;
   3246   if (hlo->opcode() == HloOpcode::kCopy) {
   3247     element_generator = [&](HloInstruction* hlo,
   3248                             const llvm_ir::IrArray::Index& index,
   3249                             const KernelCodegenInfo* kernel_info,
   3250                             llvm::Value* y_loc, llvm::Value* x_loc,
   3251                             int64 x_iter_num) {
   3252       EmitTileElementForCopy(hlo, index, kernel_info, y_loc, x_loc, x_iter_num);
   3253     };
   3254   } else {
   3255     DCHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
   3256     element_generator =
   3257         [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
   3258             const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
   3259             llvm::Value* x_loc, int64 x_iter_num) {
   3260           EmitTileElementForFusion(hlo, index, kernel_info, y_loc, x_loc,
   3261                                    x_iter_num);
   3262         };
   3263   }
   3264   KernelCodegenInfo kernel_info(&mapping_scheme);
   3265   KernelCodeGenerator kernel_generator(std::move(element_generator));
   3266   return EmitKernel(hlo, tiled_param_ids, kernel_generator, &kernel_info);
   3267 }
   3268 
   3269 namespace {
   3270 // A recursive function to inspect the users of a parameter to determine
   3271 // whether it's safe for that parameter to participate in a shared-memory
   3272 // transpose.
   3273 //
   3274 // Consider a fusion parameter P for which we might want to use a shmem
   3275 // transpose.  If we do, we use a GPU thread block to preload a tile of P with
   3276 // indices [z, y..y+31, x..x+31] to compute an output tile with the same indices
   3277 // cooperatively, where z, y, x are the indices for the normalized input/output
   3278 // tensor (see the document for FindTranspose021 for the definition of
   3279 // normalized tensor for 0-2-1 transpose). This shmem transpose implementation
   3280 // requires that the computation of the output tile only read elements within
   3281 // the preload tile. If this is not true, we can't use a shmem transpose for P.
   3282 //
   3283 // If the computation of output element [z, y, x] only requires the element of
   3284 // P with the same indices, the shmem transpose implementation can be applied
   3285 // to P safely. This is a sufficient but not necessary condition. We check all
   3286 // the transitive users of P to see if we can find a user that may violate
   3287 // this condition. If no such user is found, we conclude that P
   3288 // is safe for shmem transpose.
   3289 //
   3290 // This is trivially true for elementwise operations and some "data-movement"
   3291 // ops like kTuple. However, it's not true for operations that can change the
   3292 // dimensions of the inputs (e.g. pad, slice), nor for the bitcast operation.
   3293 // For example:
   3294 //
   3295 // fused_computation {
   3296 //   param_0 = f32[64,64]{1,0} parameter(0)
   3297 //   ROOT bitcast = f32[64,64]{0,1} bitcast(param_0)
   3298 // }
   3299 // The output element at logical address [0, 63] depends on the input element
   3300 // at logical address [63, 0], which would not be within the shared-memory
   3301 // block.
   3302 //
   3303 // TODO(bixia): In order to extend this to kInput fusion, that is, reduction
   3304 // with transpose, we only need to end the use-chain checking at the input of
   3305 // a reduce operation. In that case, the above description of "output" applies
   3306 // to the result of such a use-chain, which provides the input to the reduce
   3307 // operation.
   3308 bool IsInstructionSafeForShmemTranspose(const HloInstruction* hlo) {
   3309   if (hlo->IsElementwise()) {
   3310     return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) {
   3311       return IsInstructionSafeForShmemTranspose(user);
   3312     });
   3313   }
   3314 
   3315   switch (hlo->opcode()) {
   3316     // Non-elementwise instructions that don't cause the shmem transpose
   3317     // to be unsafe, including the instructions that don't currently fuse.
   3318     case HloOpcode::kGetDimensionSize:
   3319       // The result of the operation doesn't rely on the content of the
   3320       // tensor. As such, there is no need to further inspect its users.
   3321       return true;
   3322     case HloOpcode::kGetTupleElement:
   3323     case HloOpcode::kMap:
   3324     case HloOpcode::kParameter:
   3325     case HloOpcode::kTuple:
   3326     case HloOpcode::kTupleSelect:
   3327       return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) {
   3328         return IsInstructionSafeForShmemTranspose(user);
   3329       });
   3330 
   3331     default:
   3332       return false;
   3333   }
   3334 }
   3335 
   3336 // Given a group of input parameters that are a 0-2-1 transpose of the outputs
   3337 // of a fusion kernel, returns the input parameters that are safe for the
   3338 // shared memory transpose implementation.
   3339 //
   3340 // When a tile based shared memory transpose is used to implement an input with
   3341 // 0-2-1 transpose, we preload a tile of the input elements
   3342 // [z, y..y+31, x..x+31] to compute the output tile elements of the same
   3343 // indices. Preloading the input tile this way is only safe when the computation
   3344 // of the output tile elements does not need any input element outside the
   3345 // preloaded tile. We inspect all the transitive users of the input parameter
   3346 // up to the fusion root instruction to see if we can find any instruction
   3347 // that can make preloading the input tile unsafe.
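        //
        // As an illustration (a made-up fusion, in the spirit of the bitcast example
        // above):
        //
        // fused_computation {
        //   param_0 = f32[64,64]{1,0} parameter(0)
        //   param_1 = f32[64,64]{1,0} parameter(1)
        //   reverse = f32[64,64]{1,0} reverse(param_1), dimensions={0}
        //   ROOT add = f32[64,64]{1,0} add(param_0, reverse)
        // }
        // Here param_0 only feeds the elementwise add, so it stays in the returned
        // list, while param_1 feeds a reverse and is filtered out: the output tile
        // could then depend on input elements outside the preloaded tile.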
   3348 std::vector<int64> FilterInputsForShmemTranspose(const HloInstruction* fusion,
   3349                                                  std::vector<int64> input_ids) {
   3350   std::vector<int64> filtered_input_ids;
   3351   for (int64 i = 0; i < input_ids.size(); ++i) {
   3352     const HloInstruction* input = fusion->fused_parameter(input_ids[i]);
   3353     if (IsInstructionSafeForShmemTranspose(input)) {
   3354       filtered_input_ids.push_back(input_ids[i]);
   3355     } else {
   3356       VLOG(10) << "Input not safe for shmem transpose " << input->ToString()
   3357                << "\n";
   3358     }
   3359   }
   3360   return filtered_input_ids;
   3361 }
   3362 
   3363 }  // namespace
   3364 
   3365 bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
   3366   HloOpcode opcode = hlo->opcode();
   3367   CHECK(opcode == HloOpcode::kFusion || opcode == HloOpcode::kCopy);
   3368   CHECK(opcode != HloOpcode::kFusion ||
   3369         hlo->fusion_kind() == HloInstruction::FusionKind::kLoop)
   3370       << "Only loop fusions are supported.";
   3371 
   3372   const Shape& output_shape = hlo->IsMultiOutputFusion()
   3373                                   ? ShapeUtil::GetSubshape(hlo->shape(), {0})
   3374                                   : hlo->shape();
   3375 
   3376   // If the output_shape is reduced to 021 shape, find all the parameters of
   3377   // the HLO that are in the corresponding 012 shape.
   3378   std::vector<int64> params_012;
   3379   optional<std::vector<int64>> reduced_dims_021;
   3380   for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
   3381        ++operand_idx) {
   3382     HloInstruction* operand = hlo->mutable_operand(operand_idx);
   3383     auto find_transpose_result =
   3384         llvm_ir::FindTranspose021(operand->shape(), output_shape);
   3385     if (!find_transpose_result.has_value()) {
   3386       continue;
   3387     }
   3388     const std::vector<int64>& curr_reduced_dims_021 = *find_transpose_result;
   3389     if (!reduced_dims_021.has_value()) {
   3390       reduced_dims_021 = curr_reduced_dims_021;
   3391     }
   3392     if (!absl::c_equal(*reduced_dims_021, curr_reduced_dims_021)) {
   3393       // There is more than one possible transpose. Instead of picking one
   3394       // transpose, we simply give up here.
   3395       return false;
   3396     }
   3397     params_012.push_back(operand_idx);
   3398   }
   3399 
   3400   if (!reduced_dims_021.has_value()) {
   3401     return false;
   3402   }
   3403 
   3404   if ((*reduced_dims_021)[1] < kMinDimensionToTransposeTiled ||
   3405       (*reduced_dims_021)[2] < kMinDimensionToTransposeTiled) {
   3406     return false;
   3407   }
   3408 
   3409   if (opcode == HloOpcode::kFusion) {
   3410     params_012 = FilterInputsForShmemTranspose(hlo, params_012);
   3411     if (params_012.empty()) {
   3412       return false;
   3413     }
   3414   }
   3415 
   3416   // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the
   3417   // elements are of size 4 bytes), and CUDA has an architectural limit of
   3418   // 48kb shared memory per SM.  (This is increased to 96kb in Volta, but we
   3419   // don't use this, in part because it eats into our L1 cache space.)
   3420   //
   3421   // For correctness we need to ensure that we don't make more than 48kb worth
   3422   // of shmem tiles per block.  And for performance, we'd probably like to use
   3423   // significantly less, so that we can fit more than one block at a time on a
   3424   // gpu core.
   3425   //
   3426   // We say without benchmarks that we want at least 3 blocks/core, which
   3427   // corresponds to at most 3 shmem tiles per block if the elements are 32
   3428   // bits wide.  We choose which params get the shmem transpose treatment
   3429   // arbitrarily; it's not clear if there's a Right Choice.
   3430   //
   3431   // This is only sound if tiled transposes are the only place where we use
   3432   // shared memory in fusions.  If in the future other fusible ops use shared
   3433   // memory, we'll have to adjust this heuristic.
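          // To make the numbers concrete: a 32*33 tile of f32 elements is
          // 32 * 33 * 4 = 4224 bytes, so the 48kb / 3 = 16kb per-block budget fits
          // three f32 tiles (or seven f16 tiles) before the loop below truncates
          // params_012.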
   3434   constexpr int kMinBlocksPerCore = 3;
   3435   constexpr int64 kShmemPerCore = 48 * 1024;
   3436   int64 shmem_used = 0;
   3437   for (int64 i = 0; i < params_012.size(); ++i) {
   3438     const HloInstruction* operand = hlo->operand(params_012[i]);
   3439     shmem_used +=
   3440         32 * 33 *
   3441         ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type());
   3442 
   3443     if (kMinBlocksPerCore * shmem_used > kShmemPerCore) {
   3444       // Erase this element and everything after it from params_012.
   3445       params_012.resize(i);
   3446       break;
   3447     }
   3448   }
   3449 
   3450   VLOG(3) << "EmitHlo021Tile: emitting 0-2-1 tiled code for " << hlo->ToString();
   3451   std::unique_ptr<KernelThunk> kernel_thunk =
   3452       BuildKernelThunk(hlo, /*implements_whole_instruction=*/true);
   3453   const LaunchDimensions launch_dimensions =
   3454       EmitHlo021Tile(hlo, *reduced_dims_021, params_012);
   3455   UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
   3456                          ir_emitter_context_->llvm_module());
   3457   AddThunkToThunkSequence(std::move(kernel_thunk));
   3458 
   3459   return true;
   3460 }
   3461 
   3462 namespace {
   3463 // Checks that the outputs of a fusion with reduction are consistent.
   3464 Status AreFusedReductionOutputsConsistent(
   3465     absl::Span<HloInstruction* const> output_instructions,
   3466     const HloInstruction* first_reduce) {
   3467   for (const HloInstruction* inst : output_instructions) {
   3468     if (inst->opcode() == HloOpcode::kReduce) {
   3469       // Shapes, layouts and dimensions must be the same for all reduces
   3470       // inside of this fusion.
   3471       TF_RET_CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape()));
   3472       TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(),
   3473                                     inst->operand(0)->shape()));
   3474       TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(),
   3475                                     inst->operand(1)->shape()));
   3476       TF_RET_CHECK(first_reduce->dimensions() == inst->dimensions());
   3477     } else {
   3478       // For extra outputs we can relax shape equality to allow different
   3479       // types (with the same number of elements). Layouts still have to
   3480       // match.
   3481       TF_RET_CHECK(ShapeUtil::CompatibleIgnoringElementType(
   3482           first_reduce->operand(0)->shape(), inst->shape()));
   3483       TF_RET_CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
   3484                                      inst->shape().layout()));
   3485     }
   3486   }
   3487   return Status::OK();
   3488 }
   3489 
   3490 // Finds the dimensions to keep for the reduction, sorts and returns the
   3491 // dimensions from minor to major.
   3492 DimensionVector GetDimensionsToKeepMinorToMajor(
   3493     const Shape& input_shape, absl::Span<const int64> dims_to_reduce) {
   3494   DimensionVector input_dims(input_shape.rank(), 0);
   3495   absl::c_iota(input_dims, 0);
   3496   DimensionVector input_dims_to_keep;
   3497   for (int input_dim : input_dims) {
   3498     auto it = absl::c_find_if(dims_to_reduce, [&](int64 dim_to_reduce) {
   3499       return dim_to_reduce == input_dim;
   3500     });
   3501     if (it == dims_to_reduce.end()) {
   3502       input_dims_to_keep.push_back(input_dim);
   3503     }
   3504   }
   3505 
   3506   // Sort the dimensions to keep from minor to major.
   3507   absl::c_sort(input_dims_to_keep, [&input_shape](int64 dim_a, int64 dim_b) {
   3508     return PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_a) <
   3509            PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_b);
   3510   });
   3511 
   3512   VLOG(10) << "dims to keep minor to major: "
   3513            << absl::StrJoin(input_dims_to_keep, ",");
   3514   return input_dims_to_keep;
   3515 }
   3516 
   3517 // Given the input shape and dimensions to reduce for the reduction to vector,
   3518 // returns <num_reduced_major, num_kept, num_reduced_minor>:
   3519 // num_reduced_major: the number of elements in the dimensions to reduce that
   3520 //   are more major than the dimensions to keep.
   3521 // num_kept: the number of elements in the contiguous dimensions to keep.
   3522 // num_reduced_minor: the number of elements in the dimensions to reduce that
   3523 //   are more minor than the dimensions to keep.
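        //
        // For example, for an f32[8,100,32] input with the default {2,1,0} layout,
        // reducing dimensions {0, 2} keeps dimension 1, so this returns
        // <8, 100, 32>.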
   3524 std::tuple<int64, int64, int64> GetReductionToVectorDimensions(
   3525     const Shape& input_shape, absl::Span<const int64> dims_to_reduce) {
   3526   DimensionVector input_dims_to_keep_minor_to_major =
   3527       GetDimensionsToKeepMinorToMajor(input_shape, dims_to_reduce);
   3528   CHECK(LayoutUtil::AreDimensionsConsecutive(
   3529       input_shape.layout(), input_dims_to_keep_minor_to_major));
   3530   int num_reduced_major = 1, num_kept = 1, num_reduced_minor = 1;
   3531   if (input_dims_to_keep_minor_to_major.empty()) {
   3532     return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor);
   3533   }
   3534   DimensionVector input_dims(input_shape.rank(), 0);
   3535   absl::c_iota(input_dims, 0);
   3536   absl::Span<const int64> minor_to_major =
   3537       LayoutUtil::MinorToMajor(input_shape);
   3538   for (int input_dim : input_dims) {
   3539     int64 curr_dim_size = input_shape.dimensions(input_dim);
   3540     if (PositionInContainer(minor_to_major, input_dim) >
   3541         PositionInContainer(minor_to_major,
   3542                             input_dims_to_keep_minor_to_major.back())) {
   3543       num_reduced_major *= curr_dim_size;
   3544     } else if (PositionInContainer(minor_to_major, input_dim) <
   3545                PositionInContainer(minor_to_major,
   3546                                    input_dims_to_keep_minor_to_major.front())) {
   3547       num_reduced_minor *= curr_dim_size;
   3548     } else {
   3549       num_kept *= curr_dim_size;
   3550     }
   3551   }
   3552 
   3553   return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor);
   3554 }
   3555 
   3556 // Returns true if all the transitive users of hlo, up to (but not including)
   3557 // the users in use_chain_endings, are elementwise operations.
   3558 bool AreUsersElementwise(const HloInstruction* hlo,
   3559                          const ConstHloInstructionSet& use_chain_endings) {
   3560   return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) {
   3561     return use_chain_endings.count(user) ||
   3562            (user->IsElementwise() &&
   3563             AreUsersElementwise(user, use_chain_endings));
   3564   });
   3565 }
   3566 
   3567 // Returns the number of fusion inputs that have the same dimensions as the
   3568 // given shape and are involved in only elementwise operations.
   3569 int64 NumInputsInvolveInOnlyElementwiseOps(
   3570     const HloInstruction* unnested_hlo, const Shape& op_shape,
   3571     const ConstHloInstructionSet& use_chain_endings) {
   3572   return absl::c_count_if(
   3573       unnested_hlo->fused_parameters(), [&](const HloInstruction* parameter) {
   3574         const Shape& parameter_shape = parameter->shape();
   3575         return ShapeUtil::SameDimensions(op_shape, parameter_shape) &&
   3576                AreUsersElementwise(parameter, use_chain_endings);
   3577       });
   3578 }
   3579 
   3580 // Returns the number of fusion inputs that have more elements than the given
   3581 // shape.
   3582 int64 NumInputsWithMoreElementsThan(const HloInstruction* unnested_hlo,
   3583                                     const Shape& shape) {
   3584   int64 num_elements = ShapeUtil::ElementsIn(shape);
   3585   return absl::c_count_if(
   3586       unnested_hlo->fused_parameters(), [&](const HloInstruction* parameter) {
   3587         return ShapeUtil::ElementsIn(parameter->shape()) > num_elements;
   3588       });
   3589 }
   3590 
   3591 // The benefit of unrolling a kInput fusion that is a column reduction comes
   3592 // from the vectorization of non-reduction fusion outputs and fusion inputs.
   3593 // On the other hand, unrolling can also introduce factors that can cause
   3594 // the kernel to run slower. This routine uses a simple heuristic to estimate
   3595 // the benefit as well as the overhead of unrolling in order to decide whether
   3596 // unrolling is beneficial for the given kInput fusion.
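        //
        // For fusions, the heuristic below counts each reduce output as not
        // vectorizable (its result is combined with an atomic add), each non-reduce
        // fusion output and each same-shaped, elementwise-only fusion input as
        // vectorizable, and each fusion input with more elements than the reduce
        // input as not vectorizable; unrolling is chosen when the vectorizable count
        // is at least as large as the non-vectorizable one.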
   3597 bool IsUnrollingColumnReductionBeneficial(const HloInstruction* unnested_hlo,
   3598                                           const Shape& input_shape,
   3599                                           int64 num_kept) {
   3600   // TODO(b/122468062): Need to investigate further whether we can remove the
   3601   // constraint on IsPowerOfTwo.
   3602   if (!IsPowerOfTwo(static_cast<uint64>(num_kept))) {
   3603     return false;
   3604   }
   3605 
   3606   if (unnested_hlo->opcode() == HloOpcode::kReduce) {
   3607     return true;
   3608   }
   3609 
   3610   CHECK_EQ(unnested_hlo->opcode(), HloOpcode::kFusion);
   3611   int64 can_be_vectorized = 0;
   3612   int64 cannot_be_vectorized = 0;
   3613   const HloInstruction* fused_root = unnested_hlo->fused_expression_root();
   3614   ConstHloInstructionSet use_chain_endings;
   3615   if (fused_root->opcode() == HloOpcode::kReduce) {
   3616     use_chain_endings.insert(fused_root);
   3617     // Atomic.add of the reduction result can't be vectorized.
   3618     cannot_be_vectorized++;
   3619   } else {
   3620     CHECK_EQ(fused_root->opcode(), HloOpcode::kTuple);
   3621     for (const HloInstruction* instr : fused_root->operands()) {
   3622       if (instr->opcode() == HloOpcode::kReduce) {
   3623         // Atomic.add of the reduction result can't be vectorized.
   3624         cannot_be_vectorized++;
   3625       } else {
   3626         // Write of the non-reduction result can be vectorized.
   3627         can_be_vectorized++;
   3628       }
   3629       use_chain_endings.insert(instr);
   3630     }
   3631   }
   3632   // Fusion inputs that have the same dimensions as the reduce input and are
   3633   // involved in only elementwise operations can be vectorized.
   3634   can_be_vectorized += NumInputsInvolveInOnlyElementwiseOps(
   3635       unnested_hlo, input_shape, use_chain_endings);
   3636   // Fusion inputs with more elements than the reduce op input must participate
   3637   // in non-elementwise operations and we assume that they are not vectorizable
   3638   // for the purpose of estimating the benefit of unrolling. If the kernel is
   3639   // unrolled even with such an assumption, and the accesses to those inputs
   3640   // turn out to be vectorizable, the compiler will still vectorize them.
   3641   cannot_be_vectorized +=
   3642       NumInputsWithMoreElementsThan(unnested_hlo, input_shape);
   3643   return can_be_vectorized >= cannot_be_vectorized;
   3644 }
   3645 
   3646 }  // namespace
   3647 
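        // Computes the kernel mapping scheme for a reduction-to-vector. The reduction
        // input is normalized to a logical [depth, height, width] shape, width being
        // the most minor dimension, and tile sizes, thread counts and block sizes are
        // chosen for the three cases handled below: scalar reduction (a single output
        // element), column reduction (no reduced dimension is more minor than the
        // kept dimensions) and row reduction. Returns the mapping scheme together
        // with a flag indicating whether this is a row reduction.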
   3648 std::tuple<KernelMappingScheme, bool>
   3649 IrEmitterUnnested::ComputeMappingSchemeAndReductionKind(
   3650     const HloInstruction* unnested_hlo, const HloInstruction* first_reduce) {
   3651   int64 depth = 1;
   3652   int64 height = 1;
   3653   int64 width = 1;
   3654   bool is_row_reduction = true;
   3655   int64 tile_size_x = 1;
   3656   int64 tile_size_y = 1;
   3657   int64 block_size_z = 1;
   3658   int64 num_threads_x = 1;
   3659   int64 num_threads_y = 1;
   3660   const Shape& input_shape = first_reduce->operand(0)->shape();
   3661   int64 num_input_elems = ShapeUtil::ElementsIn(input_shape);
   3662   int64 num_output_elems = ShapeUtil::ElementsIn(first_reduce->shape());
   3663   int64 num_reduced_major, num_kept, num_reduced_minor;
   3664   std::tie(num_reduced_major, num_kept, num_reduced_minor) =
   3665       GetReductionToVectorDimensions(input_shape, first_reduce->dimensions());
   3666   CHECK_EQ(num_output_elems, num_kept);
   3667   bool dilated_x = true;
   3668 
   3669   if (num_kept == 1) {
   3670     // Scalar reduction is a special row reduction with depth = height = 1.
   3671     width = num_input_elems;
   3672     tile_size_x = kWarpSize * 16;
   3673     num_threads_x = kWarpSize;
   3674   } else if (num_reduced_minor == 1) {
   3675     // Column reduction reduces inputs with dimension [height, width], where
   3676     // width is the minor dimension, to dimension [width].
   3677     height = num_reduced_major;
   3678     width = num_kept;
   3679     is_row_reduction = false;
   3680     // Column reduction without transpose doesn't require communication among
   3681     // threads processing elements in the same tile. The current implementation
   3682     // only supports the use of one hardware thread block to process one block of
   3683     // tiles in the KernelMappingScheme. We try to use one thread to compute
   3684     // the partial results for two tensor elements and to maximize the values of
   3685     // num_threads_x and tile_size_x to allow a bigger hardware thread block.
   3686     int64 hw_threads_per_block_limit =
   3687         ThreadsPerBlockLimit(ir_emitter_context_->device_description());
   3688     if (IsUnrollingColumnReductionBeneficial(unnested_hlo, input_shape,
   3689                                              num_kept)) {
   3690       tile_size_x = std::min(2 * hw_threads_per_block_limit, num_kept);
   3691       num_threads_x = tile_size_x / 2;
   3692       dilated_x = false;
   3693     } else {
   3694       tile_size_x = std::min(hw_threads_per_block_limit, num_kept);
   3695       num_threads_x = tile_size_x;
   3696     }
   3697     int64 kNumElementsPerPartialSum = 128;
   3698     tile_size_y = kNumElementsPerPartialSum;
   3699   } else {
   3700     // Row reduction reduces inputs with dimension [depth, height, width],
   3701     // where width is the most minor dimension, to dimension [height].
   3702     depth = num_reduced_major;
   3703     height = num_kept;
   3704     width = num_reduced_minor;
   3705     num_threads_x = kWarpSize;
   3706     if (width % (kWarpSize * 64) == 0) {
   3707       tile_size_x = kWarpSize * 64;
   3708     } else {
   3709       tile_size_x = kWarpSize * 8;
   3710       block_size_z = 8;
   3711       while (depth % block_size_z != 0) {
   3712         block_size_z -= 1;
   3713       }
   3714     }
   3715   }
   3716   DCHECK_EQ(depth * height * width, num_input_elems);
   3717   VLOG(10) << "is_row_reduction " << is_row_reduction << " " << depth << " "
   3718            << height << " " << width;
   3719 
   3720   DimensionVector dims_in_elem{depth, height, width};
   3721   DimensionVector req_block_sizes{block_size_z, 1, 1};
   3722   llvm_ir::KernelMappingScheme mapping_scheme(
   3723       dims_in_elem, tile_size_y, tile_size_x, req_block_sizes, num_threads_y,
   3724       num_threads_x, &b_);
   3725   mapping_scheme.SetDilatedX(dilated_x);
   3726   return std::make_tuple(mapping_scheme, is_row_reduction);
   3727 }
   3728 
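        // Emits a reduction to vector: either an unnested kReduce or a kInput fusion
        // whose root is (or is a tuple containing) a reduce. An initializer thunk is
        // built for every reduce output, followed by a single kernel thunk that
        // computes all the outputs; the thunks are wrapped in a SequentialThunk.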
   3729 Status IrEmitterUnnested::EmitReductionToVector(HloInstruction* unnested_hlo) {
   3730   VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString();
   3731 
   3732   HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion
   3733                                         ? unnested_hlo->fused_expression_root()
   3734                                         : unnested_hlo;
   3735   absl::Span<HloInstruction* const> output_instructions =
   3736       GetOutputInstructions(&reduce_or_tuple);
   3737   const HloInstruction* first_reduce =
   3738       GetFirstReduceInstruction(output_instructions);
   3739 
   3740   if (output_instructions.size() > 1) {
   3741     TF_RETURN_IF_ERROR(
   3742         AreFusedReductionOutputsConsistent(output_instructions, first_reduce));
   3743   }
   3744 
   3745   // Build an initializer thunk to initialize each reduction output.
   3746   std::vector<std::unique_ptr<Thunk>> thunks;
   3747   for (int i = 0, e = output_instructions.size(); i != e; ++i) {
   3748     if (output_instructions[i]->opcode() != HloOpcode::kReduce) {
   3749       continue;
   3750     }
   3751     TF_ASSIGN_OR_RETURN(
   3752         std::unique_ptr<Thunk> initializer_thunk,
   3753         BuildInitializerThunk(unnested_hlo,
   3754                               (output_instructions[i] == reduce_or_tuple)
   3755                                   ? ShapeIndex()
   3756                                   : ShapeIndex({i})));
   3757     thunks.push_back(std::move(initializer_thunk));
   3758   }
   3759 
   3760   // Build a kernel thunk to compute all the outputs.
   3761   std::unique_ptr<KernelThunk> kernel_thunk =
   3762       BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false);
   3763 
   3764   const Shape& input_shape = first_reduce->operand(0)->shape();
   3765   // The layout of a reduction input is either set by LayoutAssignment for
   3766   // unnested kReduce or by InstructionFusion for fused kReduce.
   3767   CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion "
   3768                                      "doesn't set the input layout of "
   3769                                   << first_reduce->ToString();
   3770 
   3771   bool is_row_reduction;
   3772   llvm_ir::KernelMappingScheme mapping_scheme;
   3773   std::tie(mapping_scheme, is_row_reduction) =
   3774       ComputeMappingSchemeAndReductionKind(unnested_hlo, first_reduce);
   3775   ReductionCodegenInfo reduction_info(&mapping_scheme, is_row_reduction);
   3776   KernelCodeGenerator kernel_generator(
   3777       /*tile_element_generator=*/
   3778       [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
   3779           const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
   3780           llvm::Value* x_loc, int64 x_iter_num) {
   3781         EmitTileElementForReduction(hlo, index, kernel_info, y_loc, x_loc,
   3782                                     x_iter_num);
   3783       },
   3784       /*block_prologue_generator=*/
   3785       [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) {
   3786         EmitPrologueForReduction(hlo, kernel_info);
   3787       },
   3788       /*block_epilogue_generator=*/
   3789       [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) {
   3790         EmitEpilogueForReduction(hlo, kernel_info);
   3791       });
   3792 
   3793   LaunchDimensions launch_dimensions =
   3794       EmitKernel(unnested_hlo, {}, kernel_generator, &reduction_info);
   3795   UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
   3796                          ir_emitter_context_->llvm_module());
   3797 
   3798   thunks.push_back(std::move(kernel_thunk));
   3799   std::unique_ptr<SequentialThunk> sequential_thunk =
   3800       absl::make_unique<SequentialThunk>(std::move(thunks), unnested_hlo);
   3801   AddThunkToThunkSequence(std::move(sequential_thunk));
   3802 
   3803   return Status::OK();
   3804 }
   3805 
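        // Emits an LLVM global for every constant buffer allocation. The literal is
        // used as the initializer when ShouldEmitLiteralInLlvmIr allows it; otherwise
        // the global is zero-initialized. All of these globals get external linkage
        // so that GpuExecutable can look them up by name.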
   3806 Status IrEmitterUnnested::EmitConstantGlobals() {
   3807   for (const BufferAllocation& allocation :
   3808        ir_emitter_context_->buffer_assignment().Allocations()) {
   3809     if (!allocation.is_constant()) {
   3810       continue;
   3811     }
   3812 
   3813     const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation);
   3814     const bool should_emit_initializer = ShouldEmitLiteralInLlvmIr(literal);
   3815     llvm::ArrayType* global_type =
   3816         llvm::ArrayType::get(b_.getInt8Ty(), allocation.size());
   3817     llvm::Constant* initializer =
   3818         should_emit_initializer
   3819             ? llvm_ir::ConvertLiteralToIrConstant(literal, module_)
   3820             : llvm::ConstantAggregateZero::get(global_type);
   3821     if (should_emit_initializer) {
   3822       VLOG(3) << "Emitted initializer for constant with shape "
   3823               << ShapeUtil::HumanString(literal.shape());
   3824     }
   3825 
   3826     // These globals will be looked up by name by GpuExecutable so we need to
   3827     // give them an external linkage.  Not all of their uses are visible in
   3828     // the LLVM IR (e.g. TupleThunk) so we can't give them a linkage that
   3829     // merely preserves their names (like available_externally); we also need
   3830     // to ensure that they stick around even if they're "unused".
   3831     //
   3832     // We may have to be more clever here in the future if we notice that
   3833     // we're keeping around too many globals because of their linkage.
   3834     llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
   3835         global_type, /*isConstant=*/should_emit_initializer,
   3836         llvm::GlobalValue::ExternalLinkage,
   3837         /*Initializer=*/initializer,
   3838         llvm_ir::ConstantBufferAllocationToGlobalName(allocation));
   3839     global_for_const->setAlignment(kConstantBufferAlignBytes);
   3840     ir_emitter_context_->llvm_module()->getGlobalList().push_back(
   3841         global_for_const);
   3842   }
   3843 
   3844   return Status::OK();
   3845 }
   3846 
   3847 }  // namespace gpu
   3848 }  // namespace xla
   3849