      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <memory>
     17 #include <string>
     18 #include <vector>
     19 
     20 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
     21 
     22 #include "llvm/ADT/StringRef.h"
     23 #include "llvm/IR/BasicBlock.h"
     24 #include "llvm/IR/Function.h"
     25 #include "llvm/IR/IRBuilder.h"
     26 #include "llvm/IR/Instructions.h"
     27 #include "llvm/IR/LLVMContext.h"
     28 #include "llvm/IR/Module.h"
     29 #include "tensorflow/compiler/xla/literal_util.h"
     30 #include "tensorflow/compiler/xla/ptr_util.h"
     31 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
     32 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
     33 #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
     34 #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
     35 #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
     36 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
     37 #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
     38 #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
     39 #include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
     40 #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
     41 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
     42 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
     43 #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
     44 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
     45 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
     46 #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
     47 #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
     48 #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
     49 #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
     50 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
     51 #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h"
     52 #include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
     53 #include "tensorflow/compiler/xla/service/gpu/while_transformer.h"
     54 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     55 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     56 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     57 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
     58 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
     59 #include "tensorflow/compiler/xla/service/llvm_ir/ops.h"
     60 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
     61 #include "tensorflow/compiler/xla/service/name_uniquer.h"
     62 #include "tensorflow/compiler/xla/shape_util.h"
     63 #include "tensorflow/compiler/xla/status_macros.h"
     64 #include "tensorflow/compiler/xla/types.h"
     65 #include "tensorflow/compiler/xla/util.h"
     66 #include "tensorflow/compiler/xla/window_util.h"
     67 #include "tensorflow/compiler/xla/xla_data.pb.h"
     68 #include "tensorflow/core/lib/core/status.h"
     69 #include "tensorflow/core/lib/gtl/array_slice.h"
     70 #include "tensorflow/core/platform/logging.h"
     71 
     72 namespace xla {
     73 namespace gpu {
     74 
     75 namespace {
     76 
     77 using llvm_ir::IrName;
     78 using tensorflow::gtl::ArraySlice;
     79 using tensorflow::gtl::nullopt;
     80 using tensorflow::gtl::optional;
     81 using tensorflow::strings::StrCat;
     82 
      83 // If a dimension is smaller than this, untiled transposition may be more
     84 // efficient.
     85 const int64 kMinDimensionToTransposeTiled = 16;
     86 
     87 // Returns true if all paths from `hlo` to `root` contain only tuples. The
      88 // result of such an HloInstruction does not need to be materialized when the
     89 // computation can have a hybrid result.
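         // For example, if root = tuple(tuple(x), y), the inner tuple reaches the root
         // via only tuples, but x does not qualify because x itself is not a tuple, so
         // this returns false for x.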
     90 bool ReachRootViaOnlyTuples(const HloInstruction& hlo,
     91                             const HloInstruction& root) {
     92   if (hlo.opcode() != HloOpcode::kTuple) {
     93     return false;
     94   }
     95 
     96   if (&hlo == &root) {
     97     return true;
     98   }
     99 
    100   for (HloInstruction* user : hlo.users()) {
    101     if (!ReachRootViaOnlyTuples(*user, root)) {
    102       return false;
    103     }
    104   }
    105 
    106   return true;
    107 }
    108 
     109 // If `hlo` is a rank-2 transpose, returns its operand; otherwise returns `hlo` itself.
    110 const HloInstruction* StripTranspose(const HloInstruction& hlo) {
    111   if (hlo.IsRank2Transpose()) {
    112     return hlo.operand(0);
    113   }
    114   return &hlo;
    115 }
    116 
     117 // Updates the launch dimensions in "thunk" and annotates the launch dimensions
    118 // of the corresponding IR kernel in "llvm_module".
    119 // Precondition: "thunk" must be a KernelThunk.
    120 void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk,
    121                             llvm::Module* llvm_module) {
    122   CHECK(Thunk::Kind::kKernel == thunk->kind());
    123   KernelThunk* kernel_thunk = static_cast<KernelThunk*>(thunk);
    124   kernel_thunk->SetLaunchDimensions(launch_dims);
    125 
    126   // Add __launch_bounds__ to metadata. This limits registers per thread to
    127   // avoid out-of-resources launching errors.
    128   llvm::NamedMDNode* nvvm_annotations_node =
    129       llvm_module->getOrInsertNamedMetadata("nvvm.annotations");
    130   llvm::Function* ir_kernel =
    131       llvm_module->getFunction(kernel_thunk->kernel_name().c_str());
    132   llvm::LLVMContext& llvm_context = llvm_module->getContext();
    133   llvm::ConstantInt* threads_per_block_ir_value = llvm::ConstantInt::get(
    134       llvm::IntegerType::get(llvm_context, /*NumBits=*/32),
    135       launch_dims.threads_per_block());
    136   // Our launch bounds are exact, so we can specify them as reqntidx rather than
    137   // maxntidx.
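           // As a rough illustration (the kernel name, parameter types, and thread count
           // below are hypothetical), the resulting IR metadata looks like:
           //   !nvvm.annotations = !{!0}
           //   !0 = !{void (i8*, i8*)* @fusion_3, !"reqntidx", i32 256}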
    138   nvvm_annotations_node->addOperand(llvm::MDNode::get(
    139       llvm_context,
    140       {llvm::ConstantAsMetadata::get(ir_kernel),
    141        llvm::MDString::get(llvm_context, "reqntidx"),
    142        llvm::ConstantAsMetadata::get(threads_per_block_ir_value)}));
    143 }
    144 
    145 // Tries to get a Slice for the given instruction at the given index, but
    146 // returns nullopt if we might not know the slice's address at runtime without
    147 // dereferencing a containing tuple.
    148 //
     149 // In particular, when XLA accepts a parameter of tuple type, the caller has the
     150 // option of telling XLA what the values inside the tuple are, or of just giving
     151 // XLA a pointer to the top-level tuple and letting us chase the pointers on the
     152 // GPU.  We therefore cannot rely on these pointers to parameter sub-buffers
     153 // being present when we run the program.
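         // For example (a hypothetical case): if a parameter p: (f32[16], f32[16]) is
         // passed as a single opaque tuple pointer, the slices at shape indices {0} and
         // {1} can only be reached by dereferencing that tuple on the device, so we
         // return nullopt for them; the top-level slice at index {} is still returned.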
    154 optional<BufferAllocation::Slice> GetKnownAtRuntimeSlice(
    155     const HloInstruction* instr, const ShapeIndex& index,
    156     const BufferAssignment& buffer_assn) {
    157   auto maybe_slice = buffer_assn.GetUniqueSlice(instr, index);
    158   if (!maybe_slice.ok()) {
    159     return nullopt;
    160   }
    161   // BufferAllocation gives a slice and alloc to every buffer accessed by XLA,
    162   // but we don't necessarily know the runtime address of sub-buffers of input
    163   // parameters.
    164   const BufferAllocation::Slice& slice = maybe_slice.ValueOrDie();
    165   const BufferAllocation* alloc = slice.allocation();
    166   if (alloc->IsInputOrOutput() && !alloc->maybe_live_out() &&
    167       !alloc->param_shape_index().empty()) {
    168     return nullopt;
    169   }
    170 
    171   // Otherwise, we will know the address of this slice at runtime without having
    172   // to dereference a tuple.
    173   return slice;
    174 }
    175 
    176 }  // namespace
    177 
    178 IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
    179                                      const HloComputation* hlo_computation,
    180                                      IrEmitterContext* ir_emitter_context)
    181     : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false),
    182       hlo_computation_(hlo_computation) {
    183   // Initialize thunk_sequence_ to an empty list of thunks.
    184   thunk_sequence_.reset(new ThunkSequence());
    185 }
    186 
    187 Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) {
    188   bindings_.UnbindAllLocalIrValues();
    189   return DfsHloVisitor::Postprocess(hlo);
    190 }
    191 
    192 namespace {
    193 bool ImplementedAsHostToDeviceMemcpy(const BufferAssignment& buffer_assignment,
    194                                      const HloInstruction& hlo) {
    195   // `hlo` needs to satisfy the following conditions to be implemented as a
    196   // host-to-device cuMemcpy.
    197   //
    198   // 1. `hlo` is a kCopy instruction.
    199   // 2. `hlo`'s only operand is a kConstant instruction.
    200   // 3. `hlo` and its operand have the same shape (thus the same layout too).
    201   // 4. The address of `hlo`'s buffer is known at runtime (without dereferencing
    202   //    pointers in a tuple).
    203   return hlo.opcode() == HloOpcode::kCopy &&
    204          hlo.operand(0)->opcode() == HloOpcode::kConstant &&
    205          ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) &&
    206          GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value();
    207 }
    208 
    209 bool ImplementedAsDeviceToDeviceMemcpy(
    210     const BufferAssignment& buffer_assignment, const HloInstruction& hlo) {
    211   // `hlo` needs to satisfy three conditions to be implemented as a
    212   // device-to-device cuMemcpy.
    213   //
    214   // 1. `hlo` is a kCopy instruction.
    215   // 2. `hlo` and its operand have the same shape (thus the same layout too).
    216   // 3. The operand to `hlo` has a buffer assignment (constants do not, for
    217   //    instance) which means the source buffer also resides on the device.
    218   return hlo.opcode() == HloOpcode::kCopy &&
    219          ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) &&
    220          GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value() &&
    221          GetKnownAtRuntimeSlice(hlo.operand(0), {}, buffer_assignment)
    222              .has_value();
    223 }
    224 }  // namespace
    225 
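         // As a rough sketch, for an instruction with two buffer arguments the prototype
         // emitted below looks like the following (the kernel name, the buffer sizes, and
         // the assumption kCudaMallocAlignBytes == 256 are purely illustrative):
         //
         //   define void @fusion_3(i8* align 256 dereferenceable(1024) %alloc0,
         //                         i8* align 256 dereferenceable(4096) %temp_buf) {
         //   entry:
         //     ret void
         //   }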
    226 llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
    227     const HloInstruction& inst,
    228     tensorflow::gtl::ArraySlice<const BufferAllocation*> args) {
    229   // Compute the kernel name. The opcode string may contain "-" which cannot be
    230   // in a PTX function name, so sanitize the name before uniquifying it.
    231   string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName(
    232       llvm_ir::SanitizeFunctionName(inst.name()));
    233 
    234   // Create the kernel and add it to the module.
    235   llvm::Module* module = ir_emitter_context_->llvm_module();
    236   llvm::LLVMContext& context = module->getContext();
    237   llvm::FunctionType* kernel_type = llvm::FunctionType::get(
    238       /*Result=*/llvm::Type::getVoidTy(context),
    239       std::vector<llvm::Type*>(args.size(), ir_builder_.getInt8PtrTy()),
    240       /*isVarArg=*/false);
    241   llvm::Function* kernel =
    242       llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage,
    243                              kernel_name.c_str(), module);
    244 
    245   // Add dereferenceable and alignment information to each of the kernel's
    246   // parameters.
    247   auto arg_it = kernel->arg_begin();
    248   for (size_t arg_no = 0; arg_no < args.size(); ++arg_no) {
    249     const BufferAllocation* alloc = args[arg_no];
    250     llvm::Argument* fn_arg = &*arg_it;
    251     ++arg_it;
    252 
    253     kernel->addDereferenceableAttr(arg_no + 1, alloc->size());
    254     kernel->addParamAttr(
    255         arg_no, llvm::Attribute::get(context, llvm::Attribute::Alignment,
    256                                      kCudaMallocAlignBytes));
    257 
    258     if (alloc->IsPreallocatedTempBuffer()) {
    259       fn_arg->setName("temp_buf");
    260     } else {
    261       fn_arg->setName(llvm_ir::AsStringRef(StrCat("alloc", alloc->index())));
    262     }
    263   }
    264 
    265   // TODO(b/65380986): Investigate if adding fast math flags for generated
    266   // kernels makes sense.
    267 
    268   // Add the declaration of this kernel to llvm.nvvm.annotations so that NVPTX
    269   // treats it as a CUDA kernel.
    270   llvm::NamedMDNode* nvvm_annotations_node =
    271       module->getOrInsertNamedMetadata("nvvm.annotations");
    272   nvvm_annotations_node->addOperand(llvm::MDNode::get(
    273       context, {llvm::ConstantAsMetadata::get(kernel),
    274                 llvm::MDString::get(context, "kernel"),
    275                 llvm::ConstantAsMetadata::get(ir_builder_.getInt32(1))}));
    276 
    277   // Update the insert point to the entry basic block.
    278   llvm::BasicBlock* entry_bb =
    279       llvm::BasicBlock::Create(context, /*Name=*/"entry", /*Parent=*/kernel);
    280 
    281   // Emit a "return void" at entry_bb's end, and set the insert point before
    282   // that return instruction.
    283   ir_builder_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb));
    284 
    285   return kernel;
    286 }
    287 
    288 Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
    289   thunk_sequence_->emplace_back(BuildKernelThunk(hlo));
    290   return IrEmitter::DefaultAction(hlo);
    291 }
    292 
    293 Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
    294   const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
    295   if (dnums.lhs_batch_dimensions_size() > 0 ||
    296       dnums.rhs_batch_dimensions_size() > 0) {
    297     return Unimplemented("Dot with batch dimensions not implemented.");
    298   }
    299   if (ImplementedAsGemm(*dot)) {
    300     thunk_sequence_->emplace_back(BuildGemmThunk(dot));
    301     return Status::OK();
    302   }
    303   thunk_sequence_->emplace_back(BuildKernelThunk(dot));
    304   return IrEmitter::HandleDot(dot);
    305 }
    306 
    307 Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
    308   thunk_sequence_->emplace_back(BuildConditionalThunk(conditional));
    309   return Status::OK();
    310 }
    311 
    312 Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) {
    313   thunk_sequence_->emplace_back(BuildKernelThunk(convolution));
    314   return IrEmitter::HandleConvolution(convolution);
    315 }
    316 
    317 Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
    318   // A CustomCall on the GPU backend can either be a custom-call to a
    319   // user-supplied kernel, or a call into a library like cudnn.
    320 
    321   // Lower custom-calls to cudnn batchnorm ops to specialized thunks.  It's part
    322   // of the contract of these cudnn batchnorm calls that the epsilon and
    323   // feature_index operands be constants.
    324   if (custom_call->custom_call_target() ==
    325       kCudnnBatchNormForwardInferenceCallTarget) {
    326     const HloInstruction* epsilon = custom_call->operand(5);
    327     CHECK(epsilon->IsConstant());
    328     float epsilon_value = epsilon->literal().Get<float>({});
    329 
    330     const HloInstruction* feature_index = custom_call->operand(6);
    331     CHECK(feature_index->IsConstant());
    332     int64 feature_index_value = feature_index->literal().Get<int64>({});
    333 
    334     thunk_sequence_->emplace_back(
    335         MakeUnique<CudnnBatchNormForwardInferenceThunk>(
    336             /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
    337             /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
    338             /*offset=*/GetAllocationSlice(*custom_call->operand(2)),
    339             /*mean=*/GetAllocationSlice(*custom_call->operand(3)),
    340             /*variance=*/GetAllocationSlice(*custom_call->operand(4)),
    341             /*epsilon=*/epsilon_value,
    342             /*feature_index=*/feature_index_value,
    343             /*output=*/GetAllocationSlice(*custom_call),
    344             /*hlo=*/custom_call));
    345     return Status::OK();
    346   }
    347 
    348   if (custom_call->custom_call_target() ==
    349       kCudnnBatchNormForwardTrainingCallTarget) {
    350     const HloInstruction* epsilon = custom_call->operand(3);
    351     CHECK(epsilon->IsConstant());
    352     float epsilon_value = epsilon->literal().Get<float>({});
    353 
    354     const HloInstruction* feature_index = custom_call->operand(4);
    355     CHECK(feature_index->IsConstant());
    356     int64 feature_index_value = feature_index->literal().Get<int64>({});
    357 
    358     // BatchNormTraining returns a tuple of three elements: data, calculated
    359     // mean, and calculated 1/sqrt(variance + epsilon).
    360     const auto& assn = ir_emitter_context_->buffer_assignment();
    361     auto output_data = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
    362     auto output_mean = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
    363     auto output_inv_stddev = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie();
    364     thunk_sequence_->emplace_back(
    365         MakeUnique<CudnnBatchNormForwardTrainingThunk>(
    366             /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
    367             /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
    368             /*offset=*/GetAllocationSlice(*custom_call->operand(2)),
    369             /*epsilon=*/epsilon_value,
    370             /*feature_index=*/feature_index_value,
    371             /*output_data=*/output_data,
    372             /*output_mean=*/output_mean,
    373             /*output_inv_stddev=*/output_inv_stddev,
    374             /*output_tuple=*/GetAllocationSlice(*custom_call),
    375             /*hlo=*/custom_call));
    376     return Status::OK();
    377   }
    378 
    379   if (custom_call->custom_call_target() == kCudnnBatchNormBackwardCallTarget) {
    380     const HloInstruction* epsilon = custom_call->operand(5);
    381     CHECK(epsilon->IsConstant());
    382     float epsilon_value = epsilon->literal().Get<float>({});
    383 
    384     const HloInstruction* feature_index = custom_call->operand(6);
    385     CHECK(feature_index->IsConstant());
    386     int64 feature_index_value = feature_index->literal().Get<int64>({});
    387 
    388     // BatchNormGrad returns a tuple of three elements: grad_data, grad_scale,
    389     // grad_offset.
    390     const auto& assn = ir_emitter_context_->buffer_assignment();
    391     auto output_grad_data = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
    392     auto output_grad_scale = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
    393     auto output_grad_offset =
    394         assn.GetUniqueSlice(custom_call, {2}).ValueOrDie();
    395     thunk_sequence_->emplace_back(MakeUnique<CudnnBatchNormBackwardThunk>(
    396         /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
    397         /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
    398         /*mean=*/GetAllocationSlice(*custom_call->operand(2)),
    399         /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)),
    400         /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)),
    401         /*epsilon=*/epsilon_value,
    402         /*feature_index=*/feature_index_value,
    403         /*output_grad_data=*/output_grad_data,
    404         /*output_grad_scale=*/output_grad_scale,
    405         /*output_grad_offset=*/output_grad_offset,
    406         /*output_tuple=*/GetAllocationSlice(*custom_call),
    407         /*hlo=*/custom_call));
    408     return Status::OK();
    409   }
    410 
    411   if (IsCustomCallToDnnConvolution(*custom_call)) {
    412     const auto& assn = ir_emitter_context_->buffer_assignment();
    413     const auto& lhs_shape = custom_call->operand(0)->shape();
    414     const auto& rhs_shape = custom_call->operand(1)->shape();
    415     const auto& conv_result_shape = custom_call->shape().tuple_shapes(0);
    416     auto lhs_slice = GetAllocationSlice(*custom_call->operand(0));
    417     auto rhs_slice = GetAllocationSlice(*custom_call->operand(1));
    418     auto tuple_result_slice = GetAllocationSlice(*custom_call);
    419     auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
    420     auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
    421 
    422     const HloInstruction* algorithm_inst = custom_call->operand(2);
    423     CHECK(algorithm_inst->IsConstant()) << algorithm_inst->ToString();
    424     int64 algorithm = algorithm_inst->literal().Get<int64>({});
    425 
    426     const HloInstruction* tensor_ops_enabled_inst = custom_call->operand(3);
    427     CHECK(tensor_ops_enabled_inst->IsConstant())
    428         << tensor_ops_enabled_inst->ToString();
    429     bool tensor_ops_enabled = tensor_ops_enabled_inst->literal().Get<bool>({});
    430 
    431     const auto& target = custom_call->custom_call_target();
    432     std::unique_ptr<ConvolutionThunk> thunk;
    433     if (target == kCudnnConvForwardCallTarget) {
    434       thunk = MakeUnique<ConvolutionThunk>(
    435           CudnnConvKind::kForward,
    436           /*input_buffer=*/lhs_slice,
    437           /*filter_buffer=*/rhs_slice,
    438           /*output_buffer=*/conv_result_slice,
    439           /*tuple_result_buffer=*/tuple_result_slice,
    440           /*scratch_buffer=*/scratch_slice,
    441           /*input_shape=*/lhs_shape,
    442           /*filter_shape=*/rhs_shape,
    443           /*output_shape=*/conv_result_shape,  //
    444           custom_call->window(), custom_call->convolution_dimension_numbers(),
    445           algorithm, tensor_ops_enabled, custom_call);
    446     } else if (target == kCudnnConvBackwardInputCallTarget) {
    447       thunk = MakeUnique<ConvolutionThunk>(
    448           CudnnConvKind::kBackwardInput,
    449           /*input_buffer=*/conv_result_slice,
    450           /*filter_buffer=*/rhs_slice,
    451           /*output_buffer=*/lhs_slice,
    452           /*tuple_result_buffer=*/tuple_result_slice,
    453           /*scratch_buffer=*/scratch_slice,
    454           /*input_shape=*/conv_result_shape,
    455           /*filter_shape=*/rhs_shape,
    456           /*output_shape=*/lhs_shape,  //
    457           custom_call->window(), custom_call->convolution_dimension_numbers(),
    458           algorithm, tensor_ops_enabled, custom_call);
    459     } else if (target == kCudnnConvBackwardFilterCallTarget) {
    460       thunk = MakeUnique<ConvolutionThunk>(
    461           CudnnConvKind::kBackwardFilter,
    462           /*input_buffer=*/lhs_slice,
    463           /*filter_buffer=*/conv_result_slice,
    464           /*output_buffer=*/rhs_slice,
    465           /*tuple_result_buffer=*/tuple_result_slice,
    466           /*scratch_buffer=*/scratch_slice,
    467           /*input_shape=*/lhs_shape,
    468           /*filter_shape=*/conv_result_shape,
    469           /*output_shape=*/rhs_shape,  //
    470           custom_call->window(), custom_call->convolution_dimension_numbers(),
    471           algorithm, tensor_ops_enabled, custom_call);
    472     } else {
    473       LOG(FATAL) << "Unexpected custom call target: "
    474                  << custom_call->custom_call_target();
    475     }
    476 
    477     thunk_sequence_->emplace_back(std::move(thunk));
    478     return Status::OK();
    479   }
    480 
    481   return IrEmitter::HandleCustomCall(custom_call);
    482 }
    483 
    484 Status IrEmitterUnnested::HandleFft(HloInstruction* fft) {
    485   TF_RET_CHECK(
    486       LayoutUtil::IsMonotonicWithDim0Major(fft->operand(0)->shape().layout()));
    487   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout()));
    488   thunk_sequence_->emplace_back(BuildFftThunk(fft));
    489   return Status::OK();
    490 }
    491 
    492 Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
    493   HloInstruction* root = fusion->fused_expression_root();
    494   // HandleFusion specializes reduction from a multi-dimensional array to a 1D
     495 // array. The specialized version requires an initializer thunk that
    496   // initializes the output array to the initial value of the reduce.
    497   if (HloInstruction::FusionKind::kInput == fusion->fusion_kind()) {
    498     switch (root->opcode()) {
    499       case HloOpcode::kReduce: {
    500         VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString();
    501         std::vector<std::unique_ptr<Thunk>> thunks;
    502         thunks.emplace_back(BuildKernelThunk(fusion));
    503         TF_RETURN_IF_ERROR(EmitInitializer(
    504             fusion, static_cast<KernelThunk*>(thunks.back().get())));
    505         bindings_.UnbindAllLocalIrValues();
    506         thunks.emplace_back(BuildKernelThunk(fusion));
    507         thunk_sequence_->emplace_back(
    508             MakeUnique<SequentialThunk>(std::move(thunks), fusion));
    509         std::vector<llvm_ir::IrArray> parameter_arrays;
    510         for (HloInstruction* operand : fusion->operands()) {
    511           parameter_arrays.push_back(GetIrArray(*operand, *fusion));
    512         }
    513         GpuElementalIrEmitter elemental_emitter(
    514             hlo_module_config_, ir_emitter_context_->llvm_module(),
    515             &ir_builder_, GetNestedComputer());
    516         FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
    517         TF_RETURN_IF_ERROR(root->Accept(&fused_emitter));
    518 
    519         Shape input_shape = root->operand(0)->shape();
    520         // EmitReductionToVector requires the input shape to have a layout, but
    521         // fused instructions don't have one. So we determine its layout from
    522         // the fusion's operands. The choice of the layout only affects
    523         // performance but not correctness.
    524         auto choose_input_layout = [](
    525             tensorflow::gtl::ArraySlice<const HloInstruction*> operands,
    526             Shape* input_shape) -> Status {
    527           // Prefer the layout of an operand whose shape is compatible with
    528           // input_shape.
    529           for (const HloInstruction* operand : operands) {
    530             if (ShapeUtil::Compatible(*input_shape, operand->shape())) {
    531               return LayoutUtil::CopyLayoutBetweenShapes(operand->shape(),
    532                                                          input_shape);
    533             }
    534           }
    535           // If no operand has a compatible shape, prefer an operand that has
    536           // the same rank at least.
    537           for (const HloInstruction* operand : operands) {
    538             if (ShapeUtil::Rank(*input_shape) ==
    539                 ShapeUtil::Rank(operand->shape())) {
    540               // Do not use CopyLayoutBetweenShapes because input_shape and
    541               // operand->shape() may be incompatible.
    542               *input_shape->mutable_layout() = operand->shape().layout();
    543               return Status::OK();
    544             }
    545           }
    546           // When all the above fails, which is rare, set the default layout.
    547           LayoutUtil::SetToDefaultLayout(input_shape);
    548           return Status::OK();
    549         };
    550         TF_RETURN_IF_ERROR(
    551             choose_input_layout(fusion->operands(), &input_shape));
    552 
    553         return EmitReductionToVector(
    554             root, input_shape, fused_emitter.GetGenerator(root->operand(0)),
    555             fused_emitter.GetGenerator(root->operand(1)), root->dimensions(),
    556             root->to_apply());
    557       }
    558       default:
    559         LOG(FATAL) << "Bad opcode for input fusion: "
    560                    << fusion->fused_expression_root()->opcode();
    561     }
    562   } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(
    563                  fusion, ir_emitter_context_->buffer_assignment())) {
    564     // Fusion node with dynamic-update-slice as the root where the op's input
    565     // (i.e. array to update) shares the same slice as its output.  In this case
    566     // we have a special algorithm that modifies the output in place without
    567     // touching the un-updated elements.
    568 
    569     // Set up kernel thunk and fused ir emitter.
    570     thunk_sequence_->emplace_back(BuildKernelThunk(fusion));
    571     std::vector<llvm_ir::IrArray> operand_arrays;
    572     for (HloInstruction* operand : fusion->operands()) {
    573       operand_arrays.push_back(GetIrArray(*operand, *fusion));
    574     }
    575     GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
    576                                             ir_emitter_context_->llvm_module(),
    577                                             &ir_builder_, GetNestedComputer());
    578 
    579     // Shape of the dynamic-update-slice's "update" operand.
    580     Shape update_shape = root->operand(1)->shape();
    581 
    582     // Array to write into.  Because this is an in-place operation, this is the
    583     // same as operand 0's array.
    584     llvm_ir::IrArray output_array = GetIrArray(*fusion, *fusion);
    585 
    586     LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
    587         update_shape, ir_emitter_context_->device_description());
    588     CHECK(Thunk::Kind::kKernel == LastThunk()->kind());
    589     UpdateLaunchDimensions(launch_dimensions,
    590                            static_cast<KernelThunk*>(LastThunk()),
    591                            ir_emitter_context_->llvm_module());
    592 
    593     return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace(
    594         fusion, operand_arrays, output_array, &elemental_emitter,
    595         launch_dimensions, &ir_builder_);
    596   }
    597   if (ImplementedAsGemm(*fusion)) {
    598     thunk_sequence_->emplace_back(BuildGemmThunk(fusion));
    599     return Status::OK();
    600   }
    601   thunk_sequence_->emplace_back(BuildKernelThunk(fusion));
    602   return IrEmitter::HandleFusion(fusion);
    603 }
    604 
    605 namespace {
    606 
    607 // Returns the indices of the first elements of all consecutive subarrays of the
    608 // given array. For example:
    609 // ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
    610 std::vector<size_t> ConsecutiveSegments(tensorflow::gtl::ArraySlice<int64> xs) {
    611   std::vector<size_t> is = {0};
    612   for (size_t i = 1; i < xs.size(); ++i) {
    613     if (1 != xs[i] - xs[i - 1]) {
    614       is.push_back(i);
    615     }
    616   }
    617   return is;
    618 }
    619 
    620 // Merges the sequences of dimensions of the given shape which start at the
    621 // given indices `segs`.
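         // For example, merging f32[2,3,5] at segment starts {0, 2} yields a shape with
         // dimensions {6, 5}.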
    622 Shape MergeDimensions(tensorflow::gtl::ArraySlice<size_t> segs,
    623                       const Shape& shape) {
    624   std::vector<int64> dimensions;
    625   for (size_t i = 1; i <= segs.size(); ++i) {
    626     dimensions.push_back(std::accumulate(
    627         shape.dimensions().begin() + segs[i - 1],
    628         shape.dimensions().begin() +
    629             (segs.size() == i ? shape.dimensions().size() : segs[i]),
    630         1, std::multiplies<int64>()));
    631   }
    632   return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
    633                                                   dimensions);
    634 }
    635 
     636 // Returns whether the layouts of the given shapes describe a 0-2-1 transpose,
     637 // and if so, the normalized and rank-reduced shapes. The shapes must have the
     638 // same dimensions, so this considers layout only.
    639 //
    640 // This function recognizes higher-rank transposes which are elementwise
    641 // equivalent to a 0-2-1 transpose.
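         // For example (a hypothetical case): for a = f32[4,8,16]{2,1,0} and
         // b = f32[4,8,16]{1,2,0}, the derived permutation is {0, 2, 1}, so this returns
         // (true, f32[4,8,16], f32[4,16,8]), both with descending layouts.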
    642 std::tuple<bool, Shape, Shape> IsTranspose021(const Shape& a, const Shape& b) {
    643   CHECK(ShapeUtil::Compatible(a, b));
    644   std::vector<int64> perm(a.dimensions().size());
    645   {
    646     auto layout_a_orig = LayoutUtil::MinorToMajor(a);
    647     std::vector<int64> layout_a(layout_a_orig.rbegin(), layout_a_orig.rend());
    648     auto layout_b_orig = LayoutUtil::MinorToMajor(b);
    649     std::vector<int64> layout_b(layout_b_orig.rbegin(), layout_b_orig.rend());
    650     for (size_t i = 0; i < perm.size(); ++i) {
    651       perm[i] = PositionInContainer(layout_b, layout_a[i]);
    652     }
    653   }
    654   auto segs = ConsecutiveSegments(perm);
    655   Shape norm_a =
    656       ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
    657   Shape norm_b =
    658       ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(b);
    659   if (3 == segs.size() && 0 == perm[0]) {
    660     Shape reduced_a = MergeDimensions(segs, norm_a);
    661     Shape reduced_b = ShapeUtil::MakeShapeWithDescendingLayout(
    662         b.element_type(),
    663         Permute({0, 2, 1}, AsInt64Slice(reduced_a.dimensions())));
    664     return std::make_tuple(true, reduced_a, reduced_b);
    665   }
    666   return std::make_tuple(false, ShapeUtil::MakeNil(), ShapeUtil::MakeNil());
    667 }
    668 
     669 // Returns whether the given shapes could be those of a 0-2-1 transpose.
    670 // As 0-2-1 is a self-inverse permutation, which shape is input or output is
    671 // arbitrary.
    672 bool AreShapesForTranspose021(const Shape& a, const Shape& b) {
    673   return 3 == b.dimensions().size() &&
    674          ShapeUtil::Compatible(
    675              ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a),
    676              ShapeUtil::PermuteDimensions(
    677                  {0, 2, 1},
    678                  ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
    679                      b)));
    680 }
    681 
     682 // Emits a tiled 0-2-1 transpose, assuming both the input and the output are laid
     683 // out from major to minor. The x- and y-dimensions are tiled in square tiles of
     684 // edge length `tile_size`. Each thread block of `tile_size` x `num_rows` threads
     685 // transposes one tile: each thread copies a row from the input to a shared
     686 // memory tile, then copies a column from the shared memory tile to the output.
    687 //
     688 // `tile_size` should usually be the same as the warp size.
    689 //
     690 // Returns the number of tiles, which equals the number of thread blocks needed.
    691 //
    692 // TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient
    693 //                   to launch fewer blocks so each transposes many tiles, and
    694 //                   in any case, the number of blocks we can launch is limited.
    695 //
    696 // This is the same algorithm in CUDA:
    697 // https://github.com/tensorflow/tensorflow/blob/d2693c8a70567cc78b2e8a9ac8020d321620ca83/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc#L189
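         // A rough CUDA sketch of the code emitted below (the names `T`, `b`, `y0`, `x0`
         // and the single unrolled fast path shown here are illustrative only):
         //
         //   __shared__ T tile[/*padded*/];
         //   // (y0, x0) is the origin of the tile assigned to this thread block.
         //   for (int i = 0; i < tile_size; i += num_rows)
         //     tile[logical_y + i][logical_x] = input[b][y0 + logical_y + i][x0 + logical_x];
         //   __syncthreads();
         //   for (int i = 0; i < tile_size; i += num_rows)
         //     output[b][x0 + logical_y + i][y0 + logical_x] = tile[logical_x][logical_y + i];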
    698 int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output,
    699                             const int64 tile_size, const int64 num_rows,
    700                             llvm::IRBuilder<>* builder) {
    701   // Adds `addend` to the given `dim` of `index`.
    702   auto offset_dim = [builder](llvm_ir::IrArray::Index index,
    703                               llvm::Value* addend, int64 dim) {
    704     index[dim] = builder->CreateAdd(index[dim], addend);
    705     return index;
    706   };
    707 
    708   CHECK(AreShapesForTranspose021(input.GetShape(), output.GetShape()));
    709 
    710   Shape input_shape =
    711       ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
    712           input.GetShape());
    713   Shape output_shape =
    714       ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
    715           output.GetShape());
    716   input = input.CastToShape(input_shape, builder);
    717   output = output.CastToShape(output_shape, builder);
    718 
    719   llvm::Type* tile_type = llvm::ArrayType::get(
    720       llvm::ArrayType::get(input.GetElementLlvmType(), tile_size),
     721       // One extra here to avoid shared memory bank conflicts
    722       tile_size + 1);
    723   auto* tile = new llvm::GlobalVariable(
    724       *builder->GetInsertBlock()->getParent()->getParent(), tile_type,
    725       /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
    726       llvm::UndefValue::get(tile_type), "tile", nullptr,
    727       llvm::GlobalValue::NotThreadLocal,
    728       /*AddressSpace=*/3 /* GPU shared memory */);
    729 
    730   // let x = threadIdx.x
    731   llvm::Value* x = llvm_ir::EmitCallToIntrinsic(
    732       llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder);
    733   llvm_ir::AddRangeMetadata(0, num_rows * tile_size,
    734                             static_cast<llvm::Instruction*>(x));
    735   x = builder->CreateIntCast(x, builder->getInt64Ty(), /*isSigned=*/true,
    736                              "thread.id.x");
    737 
    738   // computing logical thread ids
    739   // logical_x = x % tile_size
    740   auto logical_x = builder->CreateURem(x, builder->getInt64(tile_size));
    741 
    742   // logical_y = x / tile_size
    743   auto logical_y = builder->CreateUDiv(x, builder->getInt64(tile_size));
    744 
     745   // `emit_cp_tile` emits code equivalent to the following pseudocode:
    746   // if (tile_size == tile_width && tile_size == tile_height) {
    747   //   unroll for (i in range(0, tile_size, num_rows)) {
    748   //     emit_cp_element(index + {0, i, 0}, y + logical_y);
    749   //   }
    750   // } else if (x < tile_width) {
    751   //   tile_height_upperbound = ceil(tile_height / num_rows) * num_rows;
    752   //   for (i in range(0, tile_height_upperbound, num_rows)) {
    753   //     y_loc = i + logical_y;
    754   //     if (y_loc < tile_height)
    755   //      emit_cp_element(index + {0, i, 0}, y_loc);
    756   //   }
    757   // }
    758   //
    759   // We use this to emit both the copy from input to tile and the copy from tile
    760   // to output.
    761   //
    762   // `index` is the origin of the row or column in the input or output array.
    763   //
     764   // `emit_cp_element(index, y)` emits code to copy a single element between the
     765   // tile and the input or output array, where `y` is the `y`-position in the
     766   // tile (whether that is a row or a column depends on whether we're copying
     767   // from the input or to the output) and `index` is the index into the input
     768   // or output array.
    769   auto emit_cp_tile = [builder, tile_size, &offset_dim, num_rows, logical_x,
    770                        logical_y](
    771                           std::function<void(const llvm_ir::IrArray::Index&,
    772                                              llvm::Value*)>
    773                               emit_cp_element,
    774                           llvm::Value* tile_width, llvm::Value* tile_height,
    775                           const llvm_ir::IrArray::Index& index,
    776                           const string& loop_name) {
    777     llvm_ir::LlvmIfData if_not_last_row = llvm_ir::EmitIfThenElse(
    778         builder->CreateAnd(
    779             builder->CreateICmpEQ(builder->getInt64(tile_size), tile_width),
    780             builder->CreateICmpEQ(builder->getInt64(tile_size), tile_height)),
    781         "not_last_row", builder);
    782     builder->SetInsertPoint(if_not_last_row.true_block->getTerminator());
    783     for (int64 i = 0; i < tile_size; i += num_rows) {
    784       auto source_idx = offset_dim(index, builder->getInt64(i), /*dim=*/1);
    785       auto y_loc = builder->CreateAdd(builder->getInt64(i), logical_y);
    786       emit_cp_element(source_idx, y_loc);
    787     }
    788     builder->SetInsertPoint(if_not_last_row.false_block->getTerminator());
    789     llvm_ir::LlvmIfData if_in_tile = llvm_ir::EmitIfThenElse(
    790         builder->CreateICmpULT(logical_x, tile_width), "x_in_tile", builder);
    791     builder->SetInsertPoint(if_in_tile.true_block->getTerminator());
    792 
    793     // tile_height_upper_bound = ceil(tile_height / num_rows) * num_rows
    794     auto tile_height_upper_bound = builder->CreateMul(
    795         builder->CreateUDiv(
    796             builder->CreateAdd(tile_height, builder->getInt64(num_rows - 1)),
    797             builder->getInt64(num_rows)),
    798         builder->getInt64(num_rows));
    799 
    800     auto loop = llvm_ir::ForLoop::EmitForLoop(
    801         loop_name, builder->getInt64(0), tile_height_upper_bound,
    802         builder->getInt64(num_rows), builder);
    803     llvm_ir::SetToFirstInsertPoint(loop->GetHeaderBasicBlock(), builder);
    804     builder->SetInsertPoint(loop->GetBodyBasicBlock()->getTerminator());
    805 
    806     auto y_loc = builder->CreateAdd(loop->GetIndVarValue(), logical_y);
    807     auto if_y_in_tile = llvm_ir::EmitIfThenElse(
    808         builder->CreateICmpULT(y_loc, tile_height), "y_in_tile", builder);
    809     builder->SetInsertPoint(if_y_in_tile.true_block->getTerminator());
    810 
    811     emit_cp_element(offset_dim(index, loop->GetIndVarValue(), /*dim=*/1),
    812                     y_loc);
    813     builder->SetInsertPoint(if_not_last_row.after_block->getTerminator());
    814   };
    815 
    816   auto input_dims_in_tiles = input_shape.dimensions();
    817   // Unpermuted dimensions are untiled.
    818   for (int i = 1; i < 3; ++i) {
    819     input_dims_in_tiles[i] =
    820         CeilOfRatio<int64>(input_dims_in_tiles[i], tile_size);
    821   }
    822   int64 num_tiles =
    823       std::accumulate(input_dims_in_tiles.begin(), input_dims_in_tiles.end(), 1,
    824                       std::multiplies<int64>());
    825   const llvm_ir::IrArray::Index input_tile_index(
    826       /*linear=*/builder->CreateIntCast(
    827           llvm_ir::AddRangeMetadata(
    828               0, num_tiles,
    829               static_cast<llvm::Instruction*>(llvm_ir::EmitCallToIntrinsic(
    830                   llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {},
    831                   builder))),
    832           builder->getInt64Ty(), /*isSigned=*/true, "block.id.x"),
    833       ShapeUtil::MakeShapeWithDescendingLayout(
    834           PRED /*arbitrary*/, AsInt64Slice(input_dims_in_tiles)),
    835       builder);
    836   const llvm_ir::IrArray::Index input_tile_origin = ({
    837     llvm_ir::IrArray::Index index = input_tile_index;
    838     for (int i = 1; i < 3; ++i) {
    839       index[i] = builder->CreateMul(index[i], builder->getInt64(tile_size),
    840                                     "tile_origin." + std::to_string(i));
    841     }
    842     index;
    843   });
    844   const llvm_ir::IrArray::Index input_index =
    845       offset_dim(offset_dim(input_tile_origin, logical_x, /*dim=*/2), logical_y,
    846                  /*dim=*/1);
    847   std::vector<llvm::Value*> tile_dims(input_shape.dimensions().size());
     848   // Only the last row or column may not have full size.
    849   for (int i = 1; i < 3; ++i) {
    850     tile_dims[i] = builder->CreateSelect(
    851         builder->CreateICmpEQ(input_tile_index[i],
    852                               builder->getInt64(input_dims_in_tiles[i] - 1)),
    853         builder->getInt64(input_shape.dimensions(i) -
    854                           (input_dims_in_tiles[i] - 1) * tile_size),
    855         builder->getInt64(tile_size), "tile_size");
    856   }
    857 
    858   // Load data from input memory to shared memory tile.
    859   emit_cp_tile(
    860       // tile[y, x] = input_array[index]
    861       [builder, tile, &input, logical_x](const llvm_ir::IrArray::Index& index,
    862                                          llvm::Value* y) {
    863         builder->CreateStore(
    864             input.EmitReadArrayElement(index, builder, "input_element"),
    865             builder->CreateGEP(tile, {builder->getInt64(0), y, logical_x}));
    866       },
    867       tile_dims[2], tile_dims[1], input_index, "input");
    868 
    869   // Wait for all threads to reach this point, lest we copy a value from tile to
     870   // output before another thread has copied it from input to tile.
    871   // This is `__syncthreads` in CUDA.
    872   llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, builder);
    873 
    874   const llvm_ir::IrArray::Index output_tile_index(
    875       Permute({0, 2, 1}, input_tile_index.multidim()));
    876   const llvm_ir::IrArray::Index output_tile_origin(
    877       Permute({0, 2, 1}, input_tile_origin.multidim()));
    878   const llvm_ir::IrArray::Index output_index =
    879       offset_dim(offset_dim(output_tile_origin, logical_x, /*dim=*/2),
    880                  logical_y, /*dim=*/1);
    881 
    882   // Store data from shared memory tile to output memory.
    883   emit_cp_tile(
    884       // output_array[index] = tile[x, y]
    885       [builder, tile, &output, logical_x](const llvm_ir::IrArray::Index& index,
    886                                           llvm::Value* y) {
    887         output.EmitWriteArrayElement(
    888             index,
    889             builder->CreateLoad(
    890                 builder->CreateGEP(tile, {builder->getInt64(0), logical_x, y}),
    891                 "output_element"),
    892             builder);
    893       },
    894       tile_dims[1], tile_dims[2], output_index, "output");
    895 
    896   return num_tiles;
    897 }
    898 
    899 }  // namespace
    900 
    901 Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
    902   if (ImplementedAsHostToDeviceMemcpy(ir_emitter_context_->buffer_assignment(),
    903                                       *copy)) {
    904     thunk_sequence_->emplace_back(BuildHostToDeviceCopyThunk(copy));
    905     return Status::OK();
    906   }
    907   if (ImplementedAsDeviceToDeviceMemcpy(
    908           ir_emitter_context_->buffer_assignment(), *copy)) {
    909     thunk_sequence_->emplace_back(BuildDeviceToDeviceCopyThunk(copy));
    910     return Status::OK();
    911   }
    912   bool is_transpose_021;
    913   Shape reduced_input_shape, reduced_output_shape;
    914   std::tie(is_transpose_021, reduced_input_shape, reduced_output_shape) =
    915       IsTranspose021(copy->operand(0)->shape(), copy->shape());
    916   if (is_transpose_021 &&
    917       reduced_input_shape.dimensions(1) >= kMinDimensionToTransposeTiled &&
    918       reduced_input_shape.dimensions(2) >= kMinDimensionToTransposeTiled) {
    919     thunk_sequence_->emplace_back(BuildKernelThunk(copy));
    920     VLOG(3) << "Emitting tiled 0-2-1 transposition";
    921     constexpr int64 tile_size = 32;
    922     constexpr int64 num_rows = 8;
    923     int64 num_tiles = EmitTranspose021Tiled(
    924         GetIrArray(*copy->operand(0), *copy)
    925             .CastToShape(reduced_input_shape, &ir_builder_),
    926         GetIrArray(*copy, *copy)
    927             .CastToShape(reduced_output_shape, &ir_builder_),
    928         tile_size, num_rows, &ir_builder_);
    929     UpdateLaunchDimensions(LaunchDimensions(num_tiles, num_rows * tile_size),
    930                            LastThunk(), ir_emitter_context_->llvm_module());
    931     return Status::OK();
    932   }
    933 
    934   return IrEmitter::HandleCopy(copy);
    935 }
    936 
    937 Status IrEmitterUnnested::EmitReductionToScalar(
    938     HloInstruction* reduce, const Shape& input_shape,
    939     const llvm_ir::ElementGenerator& input_gen,
    940     const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) {
    941   // Number of elements processed by a single thread.
    942   constexpr int64 kTileSize = 16;
    943   int64 num_elems = ShapeUtil::ElementsIn(input_shape);
    944 
    945   // Round up the number of tiles to a multiple of the warp size.  This is
    946   // necessary for correctness.  We launch one thread per tile, and if the
     947   // number of threads isn't a multiple of the warp size, our
    948   // shuffles will read from inactive threads, producing undefined values.
    949   int64 num_tiles =
    950       RoundUpToNearest(CeilOfRatio(num_elems, kTileSize), kWarpSize);
    951 
    952   // Check whether every thread will process a full tile's worth of elements
    953   // without reading outside the bounds of the input.  If this is true, we can
    954   // skip some bounds checks in the final algorithm.
    955   bool all_threads_in_bounds = num_tiles * kTileSize == num_elems;
    956 
    957   // __global__ void full_reduce_kernel() {
    958   //   x_in_tiles = threadIdx.x + blockIdx.x * blockDim.x;
    959   //   x = x_in_tiles * kTileSize;
    960   //
    961   //   partial_result = init_value;
    962   //   if (all_threads_in_bounds || x + kTileSize <= num_elems) {
    963   //     for (i = 0; i < kTileSize; ++i) {
    964   //       partial_result = Reducer(partial_result, input[x + i]);
    965   //     }
    966   //   } else {
    967   //     for (i = 0; i < kTileSize; ++i) {
    968   //       if (x + i < num_elems) {
    969   //         partial_result = Reducer(partial_result, input[x + i]);
    970   //       }
    971   //     }
    972   //   }
    973   //   for (i = warpSize / 2; i > 0; i /= 2) {
    974   //     partial_result = Reducer(partial_result,
    975   //                              __shfl_down(partial_result, i));
    976   //   }
    977   //   if (lane_id == 0) {
    978   //     AtomicReducer(&output[y], partial_result);
    979   //   }
    980   // }
    981   //
    982   // // Choose num_blocks and threads_per_block such that:
    983   // //
    984   // //   num_blocks * threads_per_block =
    985   // //     RoundUpToNextMultipleOf(Ceil(num_elems / kTileSize), warpSize),
    986   // //
    987   // // and threads_per_block is a multiple of warpSize.
    988   // reduce_kernel<<<num_blocks, threads_per_block>>>();
    989   //
    990   auto loop_body_emitter =
    991       [=](const llvm_ir::IrArray::Index& tile_index) -> Status {
    992     llvm::Type* element_ir_type =
    993         llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
    994     llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
    995         element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result");
    996     {
    997       TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value,
    998                           init_value_gen(llvm_ir::IrArray::Index({})));
    999       ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
   1000     }
   1001 
   1002     llvm::Value* x_in_tiles = tile_index[0];
   1003 
   1004     // Emit an inner for-loop that reduces the elements in the tile.
   1005     auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
   1006       std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
   1007           llvm_ir::ForLoop::EmitForLoop("element_id_in_tile",
   1008                                         ir_builder_.getInt64(0),
   1009                                         ir_builder_.getInt64(kTileSize),
   1010                                         ir_builder_.getInt64(1), &ir_builder_);
   1011 
   1012       // Emit the body of the partial reduction loop.
   1013       llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
   1014                                      &ir_builder_);
   1015       llvm::Value* x = ir_builder_.CreateNSWAdd(
   1016           ir_builder_.CreateNSWMul(x_in_tiles, ir_builder_.getInt64(kTileSize)),
   1017           tile_element_loop->GetIndVarValue());
    1018       // Unless we know the tile is entirely in bounds, we have to emit an
   1019       // x-in-bounds check before reading from the input.
   1020       if (!tile_in_bounds) {
   1021         llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
   1022             ir_builder_.CreateICmpULT(x, ir_builder_.getInt64(num_elems)),
   1023             "x_in_bounds", &ir_builder_);
   1024 
   1025         // Emit code that reads the input element and accumulates it to
   1026         // the partial reduction result.
   1027         llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
   1028       }
   1029       llvm_ir::IrArray::Index input_index(
   1030           /*linear=*/x, input_shape, &ir_builder_);
   1031       llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
   1032       TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, input_gen(input_index));
   1033       ir_builder_.CreateStore(input_ir_value, input_address);
   1034       return (EmitCallToNestedComputation(
   1035           *reducer, {partial_reduction_result_address, input_address},
   1036           partial_reduction_result_address));
   1037     };
   1038 
   1039     // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's
   1040     // immediately beyond the tile.
   1041     llvm::Value* x_end = ir_builder_.CreateNSWAdd(
   1042         ir_builder_.getInt64(kTileSize),
   1043         ir_builder_.CreateNSWMul(x_in_tiles, ir_builder_.getInt64(kTileSize)));
    1044     // The tile is entirely in bounds if all_threads_in_bounds or
   1045     // x_end <= num_elems.
   1046     llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
   1047         ir_builder_.CreateICmpULE(x_end, ir_builder_.getInt64(num_elems)),
   1048         ir_builder_.getInt1(all_threads_in_bounds));
   1049     llvm_ir::LlvmIfData if_tile_in_bounds_data =
   1050         llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_);
   1051     llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block,
   1052                                    &ir_builder_);
   1053     TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true));
   1054     llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block,
   1055                                    &ir_builder_);
   1056     TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false));
   1057 
   1058     // After the if-then-else statement on tile_in_bounds, emit calls to
   1059     // shfl_down that accumulate the partial reduction results of all threads
   1060     // from the warp.
   1061     llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block,
   1062                                    &ir_builder_);
   1063     int bit_width = llvm_ir::GetSizeInBits(element_ir_type);
   1064     // bitcast cannot be applied to aggregate types (even packed ones), so we
   1065     // instead bitcast addresses of load/store to intN* of the same bit-width.
   1066     llvm::Type* shuffle_ir_type = element_ir_type->isStructTy()
   1067                                       ? ir_builder_.getIntNTy(bit_width)
   1068                                       : element_ir_type;
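            // For instance, a C64 element is emitted as a struct of two floats, so its
            // 64 bits are shuffled as a single i64, while a plain F32 element is
            // shuffled directly.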
   1069     for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1;
   1070          shuffle_distance /= 2) {
   1071       llvm::Value* partial_reduction_result = ir_builder_.CreateLoad(
   1072           ir_builder_.CreateBitCast(partial_reduction_result_address,
   1073                                     shuffle_ir_type->getPointerTo()),
   1074           "partial_reduction_result");
   1075       llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca(
   1076           element_ir_type, nullptr, "result_from_other_lane");
   1077       ir_builder_.CreateStore(
   1078           EmitShuffleDown(partial_reduction_result,
   1079                           ir_builder_.getInt32(shuffle_distance), &ir_builder_),
   1080           ir_builder_.CreateBitCast(result_from_other_lane,
   1081                                     shuffle_ir_type->getPointerTo()));
   1082       TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
   1083           *reducer, {partial_reduction_result_address, result_from_other_lane},
   1084           partial_reduction_result_address));
   1085     }
   1086 
   1087     const HloInstruction* output =
   1088         reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
   1089 
   1090     // Emit an atomic operation that accumulates the partial reduction result of
   1091     // lane 0 (which holds the partially accumulated result for its warp) to the
   1092     // output element.
   1093     llvm::Value* lane_id = ir_builder_.CreateURem(
   1094         x_in_tiles, ir_builder_.getInt64(kWarpSize), "lane_id");
   1095     llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
   1096         ir_builder_.CreateICmpEQ(lane_id, ir_builder_.getInt64(0)),
   1097         "lane_id_is_zero", &ir_builder_);
   1098     llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block,
   1099                                    &ir_builder_);
   1100     llvm::Value* output_address =
   1101         GetIrArray(*output, *output)
   1102             .EmitArrayElementAddress(
   1103                 llvm_ir::IrArray::Index(/*linear=*/ir_builder_.getInt64(0),
   1104                                         output->shape(), &ir_builder_),
   1105                 &ir_builder_, "output_element_address");
   1106     return EmitAtomicOperationForNestedComputation(
   1107         *reducer, output_address, partial_reduction_result_address);
   1108   };
   1109 
   1110   // Emit a parallel loop that iterates through all input tiles, one per thread.
   1111   Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
   1112       reduce->shape().element_type(), {num_tiles}, {0});
   1113   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
   1114       tiled_input_shape, ir_emitter_context_->device_description());
   1115   CHECK(LastThunk()->kind() == Thunk::Kind::kSequential);
   1116   UpdateLaunchDimensions(
   1117       launch_dimensions,
   1118       static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(),
   1119       ir_emitter_context_->llvm_module());
   1120   return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape,
   1121                              launch_dimensions, &ir_builder_)
   1122       .EmitLoop(IrName(reduce));
   1123 }
   1124 
   1125 Status IrEmitterUnnested::EmitColumnReduction(
   1126     int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape,
   1127     const llvm_ir::ElementGenerator& input_gen,
   1128     const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) {
   1129   // Divide the input matrix into tiles of size Kx1. For example, when the
   1130   // input matrix is 4x4 and K=2, the tiled matrix looks like
   1131   //
   1132   //   0123
   1133   //   0123
   1134   //   4567
   1135   //   4567  // Numbers indicate tile IDs.
   1136   //
   1137   // Each tile is first partially reduced to a scalar by a thread, and then the
   1138   // scalar is accumulated to the output vector using atomic operations. We
   1139   // choose 16 as the tile size, which matches Eigen's ColumnReduceKernel.
   1140   constexpr int64 kTileSize = 16;
   1141   // If the height is not a multiple of the tile size, we pad the bottom of the
   1142   // input matrix.
   1143   const int64 height_in_tiles = CeilOfRatio(height, kTileSize);
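          // For illustration: height = 35 with kTileSize = 16 gives height_in_tiles =
          // 3; the last tile covers only rows 32..34 and is handled by the
          // out-of-bounds path emitted below.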
   1144 
   1145   // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x;
   1146   //      linear_index < height_in_tiles * width;
   1147   //      linear_index += blockDim.x * gridDim.x) {
   1148   //   y_in_tiles = linear_index / width;
   1149   //   x = linear_index % width;
   1150   //
   1151   //   partial_result = init_value;
   1152   //   if (height % kTileSize == 0 ||
   1153   //       y_in_tiles * kTileSize + kTileSize <= height) {
   1154   //     for (element_id_in_tile : range(kTileSize)) {
   1155   //       y = y_in_tiles * kTileSize + element_id_in_tile;
   1156   //       partial_result = Reducer(partial_result, input[y][x]);
   1157   //     }
   1158   //   } else {
   1159   //     for (element_id_in_tile : range(kTileSize)) {
   1160   //       y = y_in_tiles * kTileSize + element_id_in_tile;
   1161   //       if (y < height) {
   1162   //         partial_result = Reducer(partial_result, input[y][x]);
   1163   //       }
   1164   //     }
   1165   //   }
   1166   //   AtomicReducer(&output[x], partial_result);
   1167   // }
   1168   auto loop_body_emitter =
   1169       [=](const llvm_ir::IrArray::Index& tile_index) -> Status {
   1170     // Emit the loop body that reduces one tile.
   1171     llvm::Type* element_ir_type =
   1172         llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
   1173     llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
   1174         element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result");
   1175     {
   1176       TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value,
   1177                           init_value_gen(llvm_ir::IrArray::Index({})));
   1178       ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
   1179     }
   1180 
   1181     // Emit an inner for-loop that partially reduces the elements in the given
   1182     // tile.
   1183     llvm::Value* y_in_tiles = tile_index[0];
   1184     llvm::Value* x = tile_index[1];
   1185 
   1186     auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
   1187       std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
   1188           llvm_ir::ForLoop::EmitForLoop("element_id_in_tile",
   1189                                         ir_builder_.getInt64(0),
   1190                                         ir_builder_.getInt64(kTileSize),
   1191                                         ir_builder_.getInt64(1), &ir_builder_);
   1192 
   1193       // Emit the body of the partial reduction loop.
   1194       llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
   1195                                      &ir_builder_);
   1196       llvm::Value* y = ir_builder_.CreateNSWAdd(
   1197           ir_builder_.CreateNSWMul(y_in_tiles, ir_builder_.getInt64(kTileSize)),
   1198           tile_element_loop->GetIndVarValue());
   1199       // Unless we know the tile is entirely in bounds, we have to emit a
   1200       // y-in-bounds check before reading from the input.
   1201       if (!tile_in_bounds) {
   1202         llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
   1203             ir_builder_.CreateICmpULT(y, ir_builder_.getInt64(height)),
   1204             "y_in_bounds", &ir_builder_);
   1205 
   1206         // Emit code that reads the input element and accumulates it to
   1207         // the partial reduction result.
   1208         llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
   1209       }
   1210       llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
   1211       {
   1212         // {y,x} is an index to input_matrix_shape [height,width]. We need to
   1213         // convert that to an index to input_shape (the shape of the operand of
   1214         // "reduce"). This conversion is composed of a transposition from
   1215         // input_shape to normalized_input_shape and a reshape from
   1216         // normalized_input_shape to input_matrix_shape.
   1217         const Shape normalized_input_shape =
   1218             ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
   1219                 input_shape);
   1220         auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape);
   1221         const std::vector<int64> transpose_dimension_mapping(
   1222             input_shape_min2maj.rbegin(), input_shape_min2maj.rend());
   1223 
   1224         const Shape input_matrix_shape =
   1225             ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(),
   1226                                                      {height, width});
   1227         const llvm_ir::IrArray::Index input_matrix_index(
   1228             {y, x}, input_matrix_shape, &ir_builder_);
   1229         const llvm_ir::IrArray::Index input_index =
   1230             input_matrix_index
   1231                 .SourceIndexOfReshape(input_matrix_shape,
   1232                                       normalized_input_shape, &ir_builder_)
   1233                 .SourceIndexOfTranspose(normalized_input_shape, input_shape,
   1234                                         transpose_dimension_mapping,
   1235                                         &ir_builder_);
   1236         TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value,
   1237                             input_gen(input_index));
   1238         ir_builder_.CreateStore(input_ir_value, input_address);
   1239       }
   1240       return (EmitCallToNestedComputation(
   1241           *reducer, {partial_reduction_result_address, input_address},
   1242           partial_reduction_result_address));
   1243     };
   1244 
   1245     // y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's
   1246     // immediately beyond the tile.
   1247     llvm::Value* y_end = ir_builder_.CreateNSWAdd(
   1248         ir_builder_.getInt64(kTileSize),
   1249         ir_builder_.CreateNSWMul(y_in_tiles, ir_builder_.getInt64(kTileSize)));
   1250     llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
   1251         ir_builder_.CreateICmpULE(y_end, ir_builder_.getInt64(height)),
   1252         ir_builder_.getInt1(height % kTileSize == 0));
   1253     // The tile is entirely in bounds if "height" is a multiple of kTileSize or
   1254     // y_end <= height.
   1255     llvm_ir::LlvmIfData if_tile_in_bounds_data =
   1256         llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_);
   1257     llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block,
   1258                                    &ir_builder_);
   1259     TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true));
   1260     llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block,
   1261                                    &ir_builder_);
   1262     TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false));
   1263 
   1264     // After the if-then-else statement on tile_in_bounds, emit atomic
   1265     // operations to accumulate the partial reduction result to the output
   1266     // element.
   1267     llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block,
   1268                                    &ir_builder_);
   1269     const HloInstruction* output =
   1270         reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
   1271     llvm::Value* output_address =
   1272         GetIrArray(*output, *output)
   1273             .EmitArrayElementAddress(
   1274                 llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_),
   1275                 &ir_builder_, "output_element_address");
   1276     return EmitAtomicOperationForNestedComputation(
   1277         *reducer, output_address, partial_reduction_result_address);
   1278   };
   1279 
   1280   // Emit a parallel loop that iterates through all input tiles.
   1281   Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
   1282       reduce->shape().element_type(), {height_in_tiles, width}, {1, 0});
   1283   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
   1284       tiled_input_shape, ir_emitter_context_->device_description());
   1285   CHECK(LastThunk()->kind() == Thunk::Kind::kSequential);
   1286   UpdateLaunchDimensions(
   1287       launch_dimensions,
   1288       static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(),
   1289       ir_emitter_context_->llvm_module());
   1290   return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape,
   1291                              launch_dimensions, &ir_builder_)
   1292       .EmitLoop(IrName(reduce));
   1293 }
   1294 
   1295 Status IrEmitterUnnested::EmitRowReduction(
   1296     int64 depth, int64 height, int64 width, HloInstruction* reduce,
   1297     const Shape& input_shape, const llvm_ir::ElementGenerator& input_gen,
   1298     const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) {
   1299   // A naive algorithm is:
   1300   // 1. Divide the input tensor into tiles of size 1x1xK.
   1301   // 2. Partially reduce each tile to a scalar using one thread.
   1302   // 3. Accumulate that scalar to the output vector using atomic operations.
   1303   //
   1304   // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x;
   1305   //      linear_index < depth * height * width_in_tiles;
   1306   //      linear_index += blockDim.x * gridDim.x) {
   1307   //   int x_in_tiles = linear_index % width_in_tiles;
   1308   //   int y = linear_index / width_in_tiles % height;
   1309   //   int z = linear_index / (height * width_in_tiles);
   1310   //   float partial_result = 0;
   1311   //   for (element_id_in_tile : range(kTileSize)) {
   1312   //     int x = x_in_tiles * kTileSize + element_id_in_tile;
   1313   //     if (x < width)
   1314   //       partial_result = reducer(partial_result, input[z][y][x]);
   1315   //   }
   1316   //   AtomicReducer(&output[y], partial_result);
   1317   // }
   1318   //
   1319   // Three optimizations are performed.
   1320   //
   1321   // 1. To coalesce global memory accesses, dilate the tile by a factor of 32
   1322   // (i.e. the warp size). For example, suppose the width is 8x32=256. Instead
   1323   // of making each tile consecutive, we make tile 0 cover columns
   1324   // [0,32,64,...,224], tile 1 columns [1,33,65,...,225], and so on. This
   1325   // ensures that threads in a warp access consecutive memory in one iteration
   1326   // (i.e. coalesced). In the above example, the warp that contains threads 0-31
   1327   // accesses columns 0-31 in the first iteration, columns 32-63 in the second
   1328   // iteration, and so on.
   1329   //
   1330   // 2. Partially accumulate the partially reduced results computed by threads
   1331   // in the same warp using shfl_down. Using shfl_down is faster than directly
   1332   // using atomic operations because shfl_down exchanges data between registers
   1333   // within a warp and threads in the same warp run in lock step (so no extra
   1334   // synchronization is needed). See
   1335   // https://devblogs.nvidia.com/parallelforall/faster-parallel-reductions-kepler/
   1336   // for details. The downside is that, to produce correct results when using
   1337   // shfl_down, we need to guarantee that threads in the same warp work on input
   1338   // elements with the same y, so the number of tiles in each row must be a
   1339   // multiple of 32.
   1340   //
   1341   // 3. Specialize the case where the entire tile is in bounds. When that is
   1342   // true, we don't need to emit "if(x<width)" inside the loop on
   1343   // element_id_in_tile, which makes the code more friendly to optimizations
   1344   // such as LICM.
   1345   //
   1346   // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x;
   1347   //      linear_index < depth * height * width_in_tiles;
   1348   //      linear_index += blockDim.x * gridDim.x) {
   1349   //   int x_in_tiles = linear_index % width_in_tiles;
   1350   //   int y = linear_index / width_in_tiles % height;
   1351   //   int z = linear_index / (height * width_in_tiles);
   1352   //   int warp_id = x_in_tiles / warpSize;
   1353   //   int lane_id = x_in_tiles % warpSize;
   1354   //   float partial_result = 0;
   1355   //   int x = warp_id * kTileSize * warpSize + lane_id;
   1356   //   if (width % (kTileSize * warpSize) == 0 ||
   1357   //       x + (kTileSize - 1) * warpSize < width) {
   1358   //     // The entire tile is in bounds.
   1359   //     for (int element_id_in_tile = 0; element_id_in_tile < kTileSize;
   1360   //        ++element_id_in_tile, x += warpSize) {
   1361   //       partial_result = Reducer(partial_result, input[z][y][x]);
   1362   //     }
   1363   //   } else {
   1364   //     // The tile is partially in bounds.
   1365   //     for (int element_id_in_tile = 0; element_id_in_tile < kTileSize;
   1366   //          ++element_id_in_tile, x += warpSize) {
   1367   //       if (x < width)
   1368   //         partial_result = Reducer(partial_result, input[z][y][x]);
   1369   //     }
   1370   //   }
   1371   //   for (shuffle_distance = 16; shuffle_distance > 0; shuffle_distance /= 2)
   1372   //     partial_result = Reducer(
   1373   //         partial_result,
   1374   //         __shfl_down_sync(CUDA_WARP_ALL, partial_result, shuffle_distance));
   1375   //   if (lane_id == 0)
   1376   //     AtomicReducer(&output[y], partial_result);
   1377   // }
   1378   //
   1379   // Choose 8 as the tile size, which matches Eigen's RowReduceKernel.
   1380   constexpr int64 kTileSize = 8;
   1381   // Round the width in tiles up to the nearest multiple of kWarpSize, so that
   1382   // the use of shfl_down is valid.
   1383   const int64 width_in_tiles =
   1384       RoundUpToNearest(CeilOfRatio(width, kTileSize), kWarpSize);
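          // For illustration: width = 300 with kTileSize = 8 gives
          // CeilOfRatio(300, 8) = 38 tiles per row, rounded up to 64 so that each row
          // is covered by a whole number of warps (64 / kWarpSize = 2 warps per row).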
   1385 
   1386   auto loop_body_emitter =
   1387       [=](const llvm_ir::IrArray::Index& tile_index) -> Status {
   1388     // Emit the loop body that reduces one tile.
   1389     llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(
   1390         input_shape.element_type(), ir_emitter_context_->llvm_module());
   1391     llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
   1392         element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result");
   1393     {
   1394       TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value,
   1395                           init_value_gen(llvm_ir::IrArray::Index({})));
   1396       ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
   1397     }
   1398 
   1399     // Emit an inner for-loop that partially reduces the elements in the given
   1400     // tile.
   1401     llvm::Value* z = tile_index[0];
   1402     llvm::Value* y = tile_index[1];
   1403     llvm::Value* x_tile = tile_index[2];
   1404     llvm::Value* warp_id = ir_builder_.CreateUDiv(
   1405         x_tile, ir_builder_.getInt64(kWarpSize), "warp_id");
   1406     llvm::Value* lane_id = ir_builder_.CreateURem(
   1407         x_tile, ir_builder_.getInt64(kWarpSize), "lane_id");
   1408 
   1409     // The x-location of the last element in this tile.
   1410     //   last_x = lane_id + warpSize * (kTileSize - 1 + warp_id * kTileSize);
   1411     llvm::Value* last_x = ir_builder_.CreateNSWAdd(
   1412         lane_id,
   1413         ir_builder_.CreateNSWMul(
   1414             ir_builder_.getInt64(kWarpSize),
   1415             ir_builder_.CreateNSWAdd(
   1416                 ir_builder_.getInt64(kTileSize - 1),
   1417                 ir_builder_.CreateNSWMul(warp_id,
   1418                                          ir_builder_.getInt64(kTileSize)))));
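            // For illustration: with kTileSize = 8 and kWarpSize = 32, the thread with
            // warp_id = 0 and lane_id = 0 reads x = 0, 32, ..., 224, so its last_x is
            // 224; the dilated tile interleaves the 32 lanes of its warp.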
   1419 
   1420     auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
   1421       std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
   1422           llvm_ir::ForLoop::EmitForLoop("element_id_in_tile",
   1423                                         ir_builder_.getInt64(0),
   1424                                         ir_builder_.getInt64(kTileSize),
   1425                                         ir_builder_.getInt64(1), &ir_builder_);
   1426 
   1427       // Emit the body of the partial reduction loop.
   1428       llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
   1429                                      &ir_builder_);
   1430       // x = lane_id + warpSize * (element_id_in_tile + warp_id * kTileSize);
   1431       llvm::Value* x = ir_builder_.CreateNSWAdd(
   1432           lane_id,
   1433           ir_builder_.CreateNSWMul(
   1434               ir_builder_.getInt64(kWarpSize),
   1435               ir_builder_.CreateNSWAdd(
   1436                   tile_element_loop->GetIndVarValue(),
   1437                   ir_builder_.CreateNSWMul(warp_id,
   1438                                            ir_builder_.getInt64(kTileSize)))));
   1439 
   1440       // Unless we know the tile is entirely in bounds, we have to emit an
   1441       // x-in-bounds check before reading from the input.
   1442       if (!tile_in_bounds) {
   1443         llvm_ir::LlvmIfData if_x_in_bounds_data = llvm_ir::EmitIfThenElse(
   1444             ir_builder_.CreateICmpULT(x, ir_builder_.getInt64(width)),
   1445             "x_in_bounds", &ir_builder_);
   1446 
   1447         // Points ir_builder_ to the then-block.
   1448         llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block,
   1449                                        &ir_builder_);
   1450       }
   1451 
   1452       // Emit code that reads the input element and accumulates it to the
   1453       // partial reduction result.
   1454       llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
   1455       {
   1456         // {z,y,x} is an index to input_3d_tensor_shape [depth,height,width]. We
   1457         // need to convert that to an index to input_shape (the shape of the
   1458         // operand of "reduce"). This conversion is composed of a transposition
   1459         // from input_shape to normalized_input_shape and a reshape from
   1460         // normalized_input_shape to input_3d_tensor_shape.
   1461         const Shape normalized_input_shape =
   1462             ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
   1463                 input_shape);
   1464         auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape);
   1465         const std::vector<int64> transpose_dimension_mapping(
   1466             input_shape_min2maj.rbegin(), input_shape_min2maj.rend());
   1467         const Shape input_3d_tensor_shape =
   1468             ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(),
   1469                                                      {depth, height, width});
   1470         const llvm_ir::IrArray::Index input_3d_tensor_index(
   1471             {z, y, x}, input_3d_tensor_shape, &ir_builder_);
   1472         const llvm_ir::IrArray::Index input_index =
   1473             input_3d_tensor_index
   1474                 .SourceIndexOfReshape(input_3d_tensor_shape,
   1475                                       normalized_input_shape, &ir_builder_)
   1476                 .SourceIndexOfTranspose(normalized_input_shape, input_shape,
   1477                                         transpose_dimension_mapping,
   1478                                         &ir_builder_);
   1479         TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value,
   1480                             input_gen(input_index));
   1481         ir_builder_.CreateStore(input_ir_value, input_address);
   1482       }
   1483       return EmitCallToNestedComputation(
   1484           *reducer, {partial_reduction_result_address, input_address},
   1485           partial_reduction_result_address);
   1486     };
   1487 
   1488     llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
   1489         ir_builder_.getInt1(width % (kTileSize * kWarpSize) == 0),
   1490         ir_builder_.CreateICmpULT(last_x, ir_builder_.getInt64(width)));
   1491     llvm_ir::LlvmIfData if_tile_in_bounds_data =
   1492         llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_);
   1493     llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block,
   1494                                    &ir_builder_);
   1495     TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true));
   1496     llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block,
   1497                                    &ir_builder_);
   1498     TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false));
   1499 
   1500     // After the if-then-else statement on tile_in_bounds, emit calls to
   1501     // shfl_down that accumulate the partial reduction results of all threads
   1502     // in the warp.
   1503     llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block,
   1504                                    &ir_builder_);
   1505     int bit_width = llvm_ir::GetSizeInBits(element_ir_type);
   1506     // bitcast cannot be applied to aggregate types (even packed ones), so we
   1507     // instead bitcast addresses of load/store to intN* of the same bit-width.
   1508     llvm::Type* shuffle_ir_type = element_ir_type->isStructTy()
   1509                                       ? ir_builder_.getIntNTy(bit_width)
   1510                                       : element_ir_type;
   1511     for (int shuffle_distance = 16; shuffle_distance >= 1;
   1512          shuffle_distance /= 2) {
   1513       llvm::Value* partial_reduction_result = ir_builder_.CreateLoad(
   1514           ir_builder_.CreateBitCast(partial_reduction_result_address,
   1515                                     shuffle_ir_type->getPointerTo()),
   1516           "partial_reduction_result");
   1517       llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca(
   1518           element_ir_type, nullptr, "result_from_other_lane");
   1519       ir_builder_.CreateStore(
   1520           EmitShuffleDown(partial_reduction_result,
   1521                           ir_builder_.getInt32(shuffle_distance), &ir_builder_),
   1522           ir_builder_.CreateBitCast(result_from_other_lane,
   1523                                     shuffle_ir_type->getPointerTo()));
   1524       TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
   1525           *reducer, {partial_reduction_result_address, result_from_other_lane},
   1526           partial_reduction_result_address));
   1527     }
   1528 
   1529     const HloInstruction* output =
   1530         reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
   1531 
   1532     // Emit an atomic operation that accumulates the partial reduction result of
   1533     // lane 0 (which holds the partially accumulated result for its warp) to the
   1534     // output element.
   1535     llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
   1536         ir_builder_.CreateICmpEQ(lane_id, ir_builder_.getInt64(0)),
   1537         "lane_id_is_zero", &ir_builder_);
   1538     llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block,
   1539                                    &ir_builder_);
   1540     llvm::Value* output_address =
   1541         GetIrArray(*output, *output)
   1542             .EmitArrayElementAddress(
   1543                 llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_),
   1544                 &ir_builder_, "output_element_address");
   1545     return EmitAtomicOperationForNestedComputation(
   1546         *reducer, output_address, partial_reduction_result_address);
   1547   };
   1548 
   1549   // Emit a parallel loop that iterates through all input tiles.
   1550   Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
   1551       reduce->shape().element_type(), {depth, height, width_in_tiles},
   1552       {2, 1, 0});
   1553   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
   1554       tiled_input_shape, ir_emitter_context_->device_description());
   1555   CHECK(LastThunk()->kind() == Thunk::Kind::kSequential);
   1556   UpdateLaunchDimensions(
   1557       launch_dimensions,
   1558       static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(),
   1559       ir_emitter_context_->llvm_module());
   1560   return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape,
   1561                              launch_dimensions, &ir_builder_)
   1562       .EmitLoop(IrName(reduce));
   1563 }
   1564 
   1565 // Figures out whether `reduce` is a row reduction, a column reduction, or a
   1566 // reduction to a scalar, and calls `EmitRowReduction`, `EmitColumnReduction`,
   1567 // or `EmitReductionToScalar` accordingly.
   1568 // Prerequisite: all the dimensions to keep are contiguous in the input layout
   1569 //               and, if `reduce` is fused, the fused subgraph is pure
   1570 //               elementwise.
   1571 Status IrEmitterUnnested::EmitReductionToVector(
   1572     HloInstruction* reduce, const Shape& input_shape,
   1573     const llvm_ir::ElementGenerator& input_gen,
   1574     const llvm_ir::ElementGenerator& init_value_gen,
   1575     tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
   1576     HloComputation* reducer) {
   1577   // This emission requires "reduce" to have an input layout. It is either set
   1578   // by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for
   1579   // a fused kReduce).
   1580   CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion "
   1581                                      "doesn't set the input layout of "
   1582                                   << reduce->ToString();
   1583 
   1584   // Specialize multi-dimensional-array-to-vector reduction.
   1585   std::vector<int64> input_dims_to_keep;
   1586   for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape);
   1587        ++input_dim) {
   1588     if (std::find(dimensions_to_reduce.begin(), dimensions_to_reduce.end(),
   1589                   input_dim) == dimensions_to_reduce.end()) {
   1590       input_dims_to_keep.push_back(input_dim);
   1591     }
   1592   }
   1593 
   1594   // Sort the dimensions to keep from minor to major, to facilitate checking
   1595   // whether another dimension is major or minor of them.
   1596   std::sort(input_dims_to_keep.begin(), input_dims_to_keep.end(),
   1597             [&input_shape](int64 dim_a, int64 dim_b) {
   1598               return PositionInContainer(LayoutUtil::MinorToMajor(input_shape),
   1599                                          dim_a) <
   1600                      PositionInContainer(LayoutUtil::MinorToMajor(input_shape),
   1601                                          dim_b);
   1602             });
   1603   // Now, if output rank is at least 1, `input_dims_to_keep.front()` is
   1604   // minormost and `input_dims_to_keep.back()` is majormost.
   1605 
   1606   // If the dimensions to keep are minormost, emit a column reduction. As all
   1607   // the dimensions to keep are contiguous, by prerequisite of
   1608   // `EmitReductionToVector`, we only need to check whether the minormost
   1609   // dimension of the input is to keep.
   1610   if (input_dims_to_keep.empty()) {
   1611     return EmitReductionToScalar(reduce, input_shape, input_gen, init_value_gen,
   1612                                  reducer);
   1613   } else if (input_dims_to_keep.front() ==
   1614              LayoutUtil::Minor(input_shape.layout(), 0)) {
   1615     // Column reduction. Treat the result of "input" as a matrix whose width
   1616     // is the most minor dimension and height the product of other dimensions,
   1617     // and treat "reduce" as a column reduction of the input matrix.
   1618     const int64 width = ShapeUtil::ElementsIn(reduce->shape());
   1619     // "width" can be zero, so don't do
   1620     //   height = ShapeUtil::ElementsIn(input_shape) / width;
   1621     int64 height = 1;
   1622     for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape);
   1623          ++input_dim) {
   1624       if (!std::count(input_dims_to_keep.begin(), input_dims_to_keep.end(),
   1625                       input_dim)) {
   1626         height *= input_shape.dimensions(input_dim);
   1627       }
   1628     }
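            // For illustration: reducing an input of shape [2,3,4] with a descending
            // layout over dimensions {0,1} keeps dimension 2, so width = 4 and
            // height = 2 * 3 = 6, i.e. a column reduction of a 6x4 matrix.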
   1629     return EmitColumnReduction(height, width, reduce, input_shape, input_gen,
   1630                                init_value_gen, reducer);
   1631   } else {
   1632     // Reduce the row dimension of a matrix or reduce dimension 0 and 2 in a
   1633     // 3D tensor. The size of dimension 1 (the height) is the size of the
   1634     // dimension to keep, the size of dimension 0 (the depth) is the product
   1635     // of dimensions that are more major than the dimension to keep, and the
   1636     // size of dimension 2 (the width) is the product of more minor
   1637     // dimensions.
   1638     int64 depth = 1;
   1639     int64 width = 1;
   1640     for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape);
   1641          ++input_dim) {
   1642       if (PositionInContainer(LayoutUtil::MinorToMajor(input_shape),
   1643                               input_dim) >
   1644           PositionInContainer(LayoutUtil::MinorToMajor(input_shape),
   1645                               input_dims_to_keep.back())) {
   1646         depth *= input_shape.dimensions(input_dim);
   1647       } else if (PositionInContainer(LayoutUtil::MinorToMajor(input_shape),
   1648                                      input_dim) <
   1649                  PositionInContainer(LayoutUtil::MinorToMajor(input_shape),
   1650                                      input_dims_to_keep.front())) {
   1651         width *= input_shape.dimensions(input_dim);
   1652       }
   1653     }
   1654     const int64 height = ShapeUtil::ElementsIn(reduce->shape());
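            // For illustration: reducing an input of shape [2,3,4] with a descending
            // layout over dimensions {0,2} keeps dimension 1, giving depth = 2,
            // height = 3, and width = 4, i.e. a row reduction of a 2x3x4 tensor.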
   1655     return EmitRowReduction(depth, height, width, reduce, input_shape,
   1656                             input_gen, init_value_gen, reducer);
   1657   }
   1658 }
   1659 
   1660 Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
   1661   auto input = reduce->operand(0);
   1662   auto init_value = reduce->operand(1);
   1663   tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce(reduce->dimensions());
   1664   HloComputation* reducer = reduce->to_apply();
   1665   // HandleReduce specializes reduction from a multi-dimensional array to a 1D
   1666   // array. The specialized version requires an initializer thunk that
   1667   // initializes the output array to the initial value of the reduce.
   1668   if (IsReductionToVector(*reduce) &&
   1669       // NVPTX backend can't do atomic cmpxchg any narrower than 32 bits
   1670       32 <= primitive_util::BitWidth(reduce->shape().element_type())) {
   1671     std::vector<std::unique_ptr<Thunk>> thunks;
   1672     thunks.emplace_back(BuildKernelThunk(reduce));
   1673     TF_RETURN_IF_ERROR(EmitInitializer(
   1674         reduce, static_cast<KernelThunk*>(thunks.back().get())));
   1675     bindings_.UnbindAllLocalIrValues();
   1676     thunks.emplace_back(BuildKernelThunk(reduce));
   1677     thunk_sequence_->emplace_back(
   1678         MakeUnique<SequentialThunk>(std::move(thunks), reduce));
   1679     return EmitReductionToVector(
   1680         reduce, input->shape(),
   1681         [&](const llvm_ir::IrArray::Index& index) {
   1682           return GetIrArray(*input, *reduce)
   1683               .EmitReadArrayElement(index, &ir_builder_);
   1684         },
   1685         [&](const llvm_ir::IrArray::Index& index) {
   1686           return GetIrArray(*init_value, *reduce)
   1687               .EmitReadArrayElement(index, &ir_builder_);
   1688         },
   1689         dimensions_to_reduce, reducer);
   1690   }
   1691 
   1692   thunk_sequence_->emplace_back(BuildKernelThunk(reduce));
   1693   return IrEmitter::HandleReduce(reduce);
   1694 }
   1695 
   1696 Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
   1697   bool all_tuple_elements_have_buffer =
   1698       c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) {
   1699         return ir_emitter_context_->buffer_assignment().HasTopLevelAllocation(
   1700             tuple_element);
   1701       });
   1702   // Tuples (especially tuples that are the final result of a computation) can
   1703   // be so huge that if we were to emit a kernel that took each tuple element as
   1704   // a parameter, we would exceed the max allowable number of parameters to a
   1705   // GPU kernel, b/31336476. As an optimization, if all tuple elements have a
   1706   // buffer, we collect their buffer addresses in a host array, and then copy
   1707   // that array to the tuple's buffer.
   1708   //
   1709   // Some tuple elements (e.g. const or bitcast of const) might not have a
   1710   // buffer -- their contents are stored in code. In that case, we fall back to
   1711   // emitting kernels which have access to their buffer addresses in code.
   1712   if (all_tuple_elements_have_buffer) {
   1713     std::vector<BufferAllocation::Slice> tuple_element_buffers;
   1714     for (const HloInstruction* tuple_element : tuple->operands()) {
   1715       tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element));
   1716     }
   1717     thunk_sequence_->emplace_back(MakeUnique<TupleThunk>(
   1718         tuple_element_buffers, GetAllocationSlice(*tuple), tuple));
   1719     return Status::OK();
   1720   }
   1721   thunk_sequence_->emplace_back(BuildKernelThunk(tuple));
   1722   return IrEmitter::HandleTuple(tuple);
   1723 }
   1724 
   1725 Status IrEmitterUnnested::HandleGetTupleElement(HloInstruction*) {
   1726   // GetTupleElement IR is emitted in the IR context of the user instruction,
   1727   // and so we do not build a kernel for GetTupleElement instructions.
   1728   return Status::OK();
   1729 }
   1730 
   1731 Status IrEmitterUnnested::HandleSelectAndScatter(
   1732     HloInstruction* select_and_scatter) {
   1733   CHECK_EQ(select_and_scatter->operand_count(), 3);
   1734   const auto* operand = select_and_scatter->operand(0);
   1735   const auto* source = select_and_scatter->operand(1);
   1736   const Window& window = select_and_scatter->window();
   1737   PrimitiveType operand_element_type = operand->shape().element_type();
   1738   const int64 rank = ShapeUtil::Rank(operand->shape());
   1739   CHECK_EQ(rank, ShapeUtil::Rank(source->shape()));
   1740   CHECK_EQ(rank, window.dimensions_size());
   1741 
   1742   {
   1743     std::vector<std::unique_ptr<Thunk>> thunks;
   1744     thunks.emplace_back(BuildKernelThunk(select_and_scatter));
   1745     TF_RETURN_IF_ERROR(EmitInitializer(
   1746         select_and_scatter, static_cast<KernelThunk*>(thunks.back().get())));
   1747     bindings_.UnbindAllLocalIrValues();
   1748     thunks.emplace_back(BuildKernelThunk(select_and_scatter));
   1749     thunk_sequence_->emplace_back(
   1750         MakeUnique<SequentialThunk>(std::move(thunks), select_and_scatter));
   1751   }
   1752 
   1753   // TODO(b/31410564): Implement dilation rate for select-and-scatter.
   1754   if (window_util::HasDilation(window)) {
   1755     return Unimplemented(
   1756         "Dilation for SelectAndScatter not implemented on GPU.");
   1757   }
   1758 
   1759   // kSelectAndScatter is implemented as two kernel launches: the first launch
   1760   // initializes the output array to the given initial value, and the second
   1761   // accumulates the "source" matrix to the selected elements in the output
   1762   // array. The first launch is already implemented by the initializer thunk
   1763   // generated earlier, so this function only needs to take care of the
   1764   // select-and-scatter part.
   1765   //
   1766   // Pseudo code for select-and-scatter:
   1767   //
   1768   // for (coordinates S in the source):  # This loop is parallel.
   1769   //   initialized_flag = false
   1770   //   for (coordinates W in the window):
   1771   //     I = S * stride + W - pad_low
   1772   //     if I within bounds of operand:
   1773   //       if !(initialized_flag and select(selected_value, operand(I))):
   1774   //         selected_value = operand(I)
   1775   //         selected_index = I
   1776   //         initialized_flag = true
   1777   //   output(selected_index) = scatter(output(selected_index), source(S))
   1778   auto loop_body_emitter =
   1779       [=](const llvm_ir::IrArray::Index& source_index) -> Status {
   1780     // Allocate space to keep the currently selected value, its index, and a
   1781     // boolean flag indicating whether the value is initialized. The
   1782     // initialized_flag is initially set to false.
   1783     llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
   1784         llvm_ir::PrimitiveTypeToIrType(operand_element_type,
   1785                                        ir_emitter_context_->llvm_module()),
   1786         "selected_value_address", &ir_builder_);
   1787     llvm::Value* selected_index_address =
   1788         llvm_ir::EmitAllocaAtFunctionEntryWithCount(
   1789             ir_builder_.getInt64Ty(), ir_builder_.getInt32(rank),
   1790             "selected_index_address", &ir_builder_);
   1791     llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
   1792         ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_);
   1793     ir_builder_.CreateStore(ir_builder_.getInt1(false),
   1794                             initialized_flag_address);
   1795 
   1796     // Create the inner loop to iterate over the window.
   1797     llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"),
   1798                                       &ir_builder_);
   1799     std::vector<int64> window_size;
   1800     for (const auto& dim : window.dimensions()) {
   1801       window_size.push_back(dim.size());
   1802       CHECK_GT(dim.size(), 0);
   1803     }
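            // window_size now holds the extent of the window along each dimension; the
            // loop nest below visits every window element for the current source
            // element.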
   1804     const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape(
   1805         ShapeUtil::MakeShape(operand_element_type, window_size), "window");
   1806     llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
   1807                                    &ir_builder_);
   1808 
   1809     // Compute the operand index to visit and evaluate whether that index is in
   1810     // bounds. Because the comparison is unsigned, a negative index wraps to a
   1811     // large value and fails the check, so this also verifies operand index >= 0.
   1812     llvm_ir::IrArray::Index operand_index(source_index.size());
   1813     llvm::Value* in_bounds_condition = ir_builder_.getInt1(true);
   1814     for (int64 i = 0; i < rank; ++i) {
   1815       llvm::Value* strided_index = ir_builder_.CreateNSWMul(
   1816           source_index[i], ir_builder_.getInt64(window.dimensions(i).stride()));
   1817       operand_index[i] = ir_builder_.CreateNSWSub(
   1818           ir_builder_.CreateNSWAdd(strided_index, window_index[i]),
   1819           ir_builder_.getInt64(window.dimensions(i).padding_low()));
   1820       llvm::Value* index_condition = ir_builder_.CreateICmpULT(
   1821           operand_index[i],
   1822           ir_builder_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
   1823       in_bounds_condition =
   1824           ir_builder_.CreateAnd(in_bounds_condition, index_condition);
   1825     }
   1826     CHECK(in_bounds_condition != nullptr);
   1827 
   1828     // Only need to do something if the operand index is within the bounds.
   1829     // First check if the initialized_flag is set.
   1830     llvm_ir::LlvmIfData if_in_bounds =
   1831         llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &ir_builder_);
   1832     llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &ir_builder_);
   1833     llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
   1834         ir_builder_.CreateLoad(initialized_flag_address), "initialized",
   1835         &ir_builder_);
   1836 
   1837     // If the initialized_flag is false, initialize the selected value and index
   1838     // with the currently visiting operand.
   1839     llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_);
   1840     const auto save_operand_index = [&](
   1841         const llvm_ir::IrArray::Index& operand_index) {
   1842       for (int64 i = 0; i < rank; ++i) {
   1843         llvm::Value* selected_index_address_slot =
   1844             ir_builder_.CreateInBoundsGEP(selected_index_address,
   1845                                           {ir_builder_.getInt32(i)});
   1846         ir_builder_.CreateStore(operand_index[i], selected_index_address_slot);
   1847       }
   1848     };
   1849     llvm_ir::IrArray operand_array = GetIrArray(*operand, *select_and_scatter);
   1850     llvm::Value* operand_data =
   1851         operand_array.EmitReadArrayElement(operand_index, &ir_builder_);
   1852     ir_builder_.CreateStore(operand_data, selected_value_address);
   1853     save_operand_index(operand_index);
   1854     ir_builder_.CreateStore(ir_builder_.getInt1(true),
   1855                             initialized_flag_address);
   1856 
   1857     // If the initialized_flag is true, call the `select` function to
   1858     // potentially update the selected value and index with the currently
   1859     // visiting operand.
   1860     llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &ir_builder_);
   1861     const Shape output_shape = ShapeUtil::MakeShape(PRED, {});
   1862     llvm::Value* operand_address =
   1863         operand_array.EmitArrayElementAddress(operand_index, &ir_builder_);
   1864     llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
   1865         llvm_ir::PrimitiveTypeToIrType(PRED,
   1866                                        ir_emitter_context_->llvm_module()),
   1867         "select_return_buffer", &ir_builder_);
   1868     TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
   1869         *select_and_scatter->select(),
   1870         {selected_value_address, operand_address}, select_return_buffer));
   1871     llvm::Value* result = ir_builder_.CreateLoad(select_return_buffer);
   1872 
   1873     // If the 'select' function returns false, update the selected value and the
   1874     // index to the currently visiting operand.
   1875     llvm::Value* cond = ir_builder_.CreateICmpNE(
   1876         result,
   1877         llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(
   1878                                    PRED, ir_emitter_context_->llvm_module()),
   1879                                0),
   1880         "boolean_predicate");
   1881     llvm_ir::LlvmIfData if_select_lhs =
   1882         llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_);
   1883     llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &ir_builder_);
   1884     ir_builder_.CreateStore(ir_builder_.CreateLoad(operand_address),
   1885                             selected_value_address);
   1886     save_operand_index(operand_index);
   1887 
   1888     // After iterating over the window elements, scatter the source element to
   1889     // the selected index of the output. The value we store at the output
   1890     // location is computed by calling the `scatter` function with the source
   1891     // value and the current output value.
   1892     llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
   1893                                    &ir_builder_);
   1894     llvm_ir::IrArray::Index selected_index;
   1895     for (int64 i = 0; i < rank; ++i) {
   1896       llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP(
   1897           selected_index_address, {ir_builder_.getInt32(i)});
   1898       selected_index.push_back(
   1899           ir_builder_.CreateLoad(selected_index_address_slot));
   1900     }
   1901     llvm::Value* source_value_address =
   1902         GetIrArray(*source, *select_and_scatter)
   1903             .EmitArrayElementAddress(source_index, &ir_builder_);
   1904     llvm::Value* output_value_address =
   1905         GetIrArray(*select_and_scatter, *select_and_scatter)
   1906             .EmitArrayElementAddress(selected_index, &ir_builder_);
   1907     return EmitAtomicOperationForNestedComputation(
   1908         *select_and_scatter->scatter(), output_value_address,
   1909         source_value_address);
   1910   };
   1911 
   1912   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
   1913       source->shape(), ir_emitter_context_->device_description());
   1914   UpdateLaunchDimensions(
   1915       launch_dimensions,
   1916       // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk
   1917       // consisting of two thunks, an initializer KernelThunk that initializes
   1918       // the output and another KernelThunk that accumulates the scattered
   1919       // elements.
   1920       static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(),
   1921       ir_emitter_context_->llvm_module());
   1922   return ParallelLoopEmitter(loop_body_emitter, source->shape(),
   1923                              launch_dimensions, &ir_builder_)
   1924       .EmitLoop(IrName(select_and_scatter));
   1925 }
   1926 
   1927 Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) {
   1928   HloComputation* condition = xla_while->while_condition();
   1929   TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) &&
   1930                condition->root_instruction()->shape().element_type() == PRED)
   1931       << "While condition computation must return bool";
   1932   // Build ForThunk for conformant while loops, otherwise build WhileThunk.
   1933   auto result = CanTransformWhileToFor(xla_while);
   1934   if (result.ok()) {
   1935     auto tuple = result.ConsumeValueOrDie();
   1936     // loop_trip_count = (limit - start + increment - 1) / increment
   1937     const int64 loop_trip_count =
   1938         (std::get<1>(tuple) - std::get<0>(tuple) + std::get<2>(tuple) - 1) /
   1939         std::get<2>(tuple);
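            // For illustration: start = 0, limit = 10, increment = 3 gives
            // (10 - 0 + 3 - 1) / 3 = 4 iterations (i = 0, 3, 6, 9).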
   1940     thunk_sequence_->emplace_back(BuildForThunk(xla_while, loop_trip_count));
   1941     VLOG(3) << "Built ForThunk for while: " << xla_while->name();
   1942   } else {
   1943     thunk_sequence_->emplace_back(BuildWhileThunk(xla_while));
   1944     VLOG(3) << "Built WhileThunk for while: " << xla_while->name()
   1945             << " while-to-for transform status: " << result.status();
   1946   }
   1947   return Status::OK();
   1948 }
   1949 
   1950 Status IrEmitterUnnested::HandleRng(HloInstruction* random) {
   1951   thunk_sequence_->push_back(BuildKernelThunk(random));
   1952   return IrEmitter::HandleRng(random);
   1953 }
   1954 
   1955 Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
   1956   thunk_sequence_->push_back(BuildKernelThunk(select));
   1957   return IrEmitter::HandleSelect(select);
   1958 }
   1959 
   1960 Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) {
   1961   thunk_sequence_->emplace_back(BuildInfeedThunk(infeed));
   1962   return Status::OK();
   1963 }
   1964 
   1965 // Figures out how to access the buffers for all subshapes of hlo's operands and
   1966 // for hlo itself (i.e. all the buffers produced by HLO).
   1967 //
   1968 // Returns a map keyed on the pair {HloInstruction, ShapeIndex}.  The value for
   1969 // this key is a pair {Slice, ShapeIndex}, where the slice tells you the root
   1970 // buffer to look in, and the ShapeIndex describes how to dereference starting
   1971 // at that buffer to get to the buffer in question.
   1972 //
   1973 // For example, if {hlo, {1}} is mapped to {slice, {3, 4}}, then the buffer for
   1974 // hlo at ShapeIndex {1} (i.e. the buffer for the second tuple element of hlo)
   1975 // is found at slice[3][4].  That is, slice is a void***, which we dereference
   1976 // twice -- first at index 3, and then at index 4 -- to get the address of our
   1977 // buffer.
   1978 //
   1979 // This function conservatively assumes that we'll touch all sub-buffers of
   1980 // every operand and of the output.
   1981 static std::map<std::pair<const HloInstruction*, ShapeIndex>,
   1982                 std::pair<BufferAllocation::Slice, ShapeIndex>>
   1983 GetHloBufferSlices(const HloInstruction* hlo,
   1984                    const BufferAssignment& buffer_assn) {
   1985   std::map<std::pair<const HloInstruction*, ShapeIndex>,
   1986            std::pair<BufferAllocation::Slice, ShapeIndex>>
   1987       slices;
   1988 
   1989   // Tries to find a slice plus an array of indices i1, ..., iN such that the
   1990   // sub-buffer for instr at index can be found at slice[i1]...[iN].
   1991   auto find_slice_for = [&](const HloInstruction* instr,
   1992                             const ShapeIndex& index)
   1993       -> optional<std::pair<BufferAllocation::Slice, ShapeIndex>> {
   1994     // Simple, common case: Is the buffer for instr known at runtime?  If so,
   1995     // we're done.
   1996     auto slice = GetKnownAtRuntimeSlice(instr, index, buffer_assn);
   1997     if (slice.has_value()) {
   1998       return {{*slice, ShapeIndex()}};
   1999     }
   2000 
   2001     // If we don't know the buffer for instr at index, see if we know the buffer
   2002     // for instr at index without its last element.  If so, we can dynamically
   2003     // find the buffer for instr by dereferencing a pointer in that buffer.
   2004     // Continue looking this way until we run out of elements in 'index'.
   2005     ShapeIndex new_index = index;
   2006     ShapeIndex gte_indices;
   2007     while (!new_index.empty()) {
   2008       gte_indices.push_front(new_index.back());
   2009       new_index.pop_back();
   2010       auto slice = GetKnownAtRuntimeSlice(instr, new_index, buffer_assn);
   2011       if (slice.has_value()) {
   2012         return {{*slice, gte_indices}};
   2013       }
   2014     }
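            // For illustration: if no slice is known for `instr` at index {1, 0} but
            // one is known at {1}, we return the slice for {1} together with
            // gte_indices {0}, and the caller dereferences one level at runtime to
            // reach the sub-buffer.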
   2015 
   2016     // If *that* didn't work, check whether instr is a GTE instruction.  If it
   2017     // is, see if we can get a buffer for its parent, and continue walking up
   2018     // parents until we find a defined buffer or we hit something that's not a
   2019     // GTE.
   2020     const HloInstruction* parent = instr;
   2021     while (parent->opcode() == HloOpcode::kGetTupleElement) {
   2022       gte_indices.push_front(parent->tuple_index());
   2023       parent = parent->operand(0);
   2024 
   2025       auto slice = GetKnownAtRuntimeSlice(parent, {}, buffer_assn);
   2026       if (slice.has_value()) {
   2027         return {{*slice, gte_indices}};
   2028       }
   2029     }
   2030 
   2031     return nullopt;
   2032   };
   2033 
   2034   // Adds entries for all subshapes of instr to `slices`.
   2035   auto add_slices_for = [&](const HloInstruction* instr) {
   2036     // GPU constants don't have buffers; don't bother looking for one.
   2037     if (instr->IsConstant()) {
   2038       return;
   2039     }
   2040 
   2041     ShapeUtil::ForEachSubshape(
   2042         instr->shape(), [&](const Shape& /*shape*/, const ShapeIndex& index) {
   2043           if (slices.count({instr, index})) {
   2044             // HLOs can have duplicate operands; don't bother redoing work.
   2045             return;
   2046           }
   2047           auto maybe_slice = find_slice_for(instr, index);
   2048           if (maybe_slice.has_value()) {
   2049             slices[{instr, index}] = *maybe_slice;
   2050           } else {
   2051             VLOG(1) << "Couldn't find buffer for " << instr->ToString()
   2052                     << " at index " << index.ToString();
   2053           }
   2054         });
   2055   };
   2056 
   2057   add_slices_for(hlo);
   2058   for (const HloInstruction* operand : hlo->operands()) {
   2059     // Conservatively assume we'll need the buffers for all subshapes of the
   2060     // operand.
   2061     add_slices_for(operand);
   2062   }
   2063 
   2064   return slices;
   2065 }
   2066 
   2067 Status IrEmitterUnnested::HandleGather(HloInstruction* gather) {
   2068   // TODO(b/72710576): Gather is not implemented on GPUs
   2069   return Unimplemented("Gather is not implemented on GPUs.");
   2070 }
   2071 
   2072 std::unique_ptr<Thunk> IrEmitterUnnested::BuildKernelThunk(
   2073     const HloInstruction* inst) {
   2074   const BufferAssignment& buffer_assn =
   2075       ir_emitter_context_->buffer_assignment();
   2076 
   2077   std::map<std::pair<const HloInstruction*, ShapeIndex>,
   2078            std::pair<BufferAllocation::Slice, ShapeIndex>>
   2079       hlo_slices = GetHloBufferSlices(inst, buffer_assn);
   2080 
   2081   // Figure out which buffer allocations need to be passed as arguments to our
   2082   // kernel.  This is simply all of the allocations referenced in hlo_slices,
   2083   // plus the XLA temp buffer (if we have it).  We always include the temp
   2084   // buffer because even if the kernel itself doesn't use it, a nested
   2085   // subcomputation within the kernel (e.g. a kMap's computation) might.
   2086   std::unordered_set<const BufferAllocation*> buffers_needed;
   2087   for (const auto& kv : hlo_slices) {
   2088     buffers_needed.insert(kv.second.first.allocation());
   2089   }
   2090   tensorflow::gtl::optional<const BufferAllocation*> temp_buffer;
   2091   for (const BufferAllocation& alloc : buffer_assn.Allocations()) {
   2092     if (alloc.IsPreallocatedTempBuffer()) {
   2093       if (!temp_buffer.has_value()) {
   2094         temp_buffer = &alloc;
   2095       } else {
   2096         LOG(FATAL) << "Multiple temp buffers found, but only one is allowed!";
   2097       }
   2098     }
   2099   }
   2100   if (temp_buffer.has_value()) {
   2101     buffers_needed.insert(*temp_buffer);
   2102   }
   2103 
   2104   // We'll pass a pointer to each of the elements of `buffers` to our kernel, in
   2105   // this order.
   2106   std::vector<const BufferAllocation*> buffers(buffers_needed.begin(),
   2107                                                buffers_needed.end());
   2108   std::sort(buffers.begin(), buffers.end(),
   2109             [](const BufferAllocation* a, const BufferAllocation* b) {
   2110               return a->index() < b->index();
   2111             });
   2112 
   2113   llvm::Function* kernel = BuildKernelPrototype(*inst, buffers);
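  // Note (added commentary): sorting by allocation index gives a deterministic
  // argument order.  The same `buffers` vector is handed both to
  // BuildKernelPrototype below, which lays out the kernel's parameters in this
  // order, and to the KernelThunk at the end of this function, which is
  // expected to pass the corresponding device pointers in the same order at
  // launch time, so the two sides stay in sync.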
   2114 
   2115   // Build a map from a BufferAllocation to the corresponding argument in our
   2116   // kernel.
   2117   std::unordered_map<const BufferAllocation*, llvm::Value*> kernel_args;
   2118   {
   2119     auto arg_it = kernel->arg_begin();
   2120     auto buffers_it = buffers.begin();
   2121     for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) {
   2122       kernel_args[*buffers_it] = arg_it;
   2123     }
   2124   }
   2125 
   2126   // For each buffer our kernel might want to touch, bind it to a value derived
   2127   // from our kernel args.
   2128   for (const auto& kv : hlo_slices) {
   2129     const HloInstruction* instr = kv.first.first;
   2130     const ShapeIndex& index = kv.first.second;
   2131     const BufferAllocation::Slice& slice = kv.second.first;
   2132     const ShapeIndex& gte_index = kv.second.second;
   2133 
   2134     VLOG(3) << "Buffer for " << instr->ToString() << " at " << index.ToString()
   2135             << " is found in slice " << slice.ToString() << " at GTE index "
   2136             << gte_index.ToString();
   2137 
   2138     llvm::Value* loc =
   2139         ir_builder_.CreateInBoundsGEP(kernel_args.at(slice.allocation()),
   2140                                       {ir_builder_.getInt64(slice.offset())});
   2141 
   2142     // If gte_index is nonempty, we have to dereference `loc` to get to the
   2143     // value we're ultimately interested in.
   2144     llvm::Type* int8_double_pointer =
   2145         llvm::PointerType::get(ir_builder_.getInt8PtrTy(), /*AddressSpace=*/0);
   2146     for (int64 idx : gte_index) {
   2147       loc = ir_builder_.CreateBitCast(loc, int8_double_pointer);
   2148       loc = ir_builder_.CreateLoad(
   2149           ir_builder_.CreateInBoundsGEP(loc, {ir_builder_.getInt64(idx)}));
   2150     }
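    // Illustrative sketch (added commentary; exact value names are made up):
    // for gte_index = {1}, the loop above emits roughly
    //   %0   = bitcast i8* %loc to i8**
    //   %1   = getelementptr inbounds i8*, i8** %0, i64 1
    //   %loc = load i8*, i8** %1
    // i.e. it treats the buffer as an array of i8* and loads the sub-buffer's
    // address at run time.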
   2151 
   2152     bindings_.BindHloToIrValue(*instr, loc, index);
   2153   }
   2154 
   2155   // Bind the temp buffer so that nested subcomputations can find it if they
   2156   // need it.
   2157   if (temp_buffer.has_value()) {
   2158     bindings_.SetTempBufferBase(kernel_args.at(*temp_buffer));
   2159   } else {
   2160     bindings_.SetTempBufferBase(
   2161         llvm::ConstantPointerNull::get(ir_builder_.getInt8PtrTy()));
   2162   }
   2163 
   2164   return MakeUnique<KernelThunk>(buffers, llvm_ir::AsString(kernel->getName()),
   2165                                  inst);
   2166 }
   2167 
   2168 std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk(
   2169     const HloInstruction* inst) {
   2170   const HloInstruction* operand = inst->operand(0);
   2171   CHECK_EQ(HloOpcode::kConstant, operand->opcode());
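  // Added note: the constant's data lives in the host-side Literal, so this
  // thunk performs a host-to-device memcpy into the buffer assigned to `inst`
  // when it executes.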
   2172   return MakeUnique<HostToDeviceCopyThunk>(
   2173       /*source_address=*/operand->literal().untyped_data(),
   2174       /*destination_buffer=*/GetAllocationSlice(*inst),
   2175       /*mem_size=*/
   2176       llvm_ir::ByteSizeOf(operand->shape(),
   2177                           ir_emitter_context_->llvm_module()->getDataLayout()),
   2178       inst);
   2179 }
   2180 
   2181 std::unique_ptr<Thunk> IrEmitterUnnested::BuildDeviceToDeviceCopyThunk(
   2182     const HloInstruction* inst) {
   2183   const HloInstruction* operand = inst->operand(0);
   2184   return MakeUnique<DeviceToDeviceCopyThunk>(
   2185       /*source_address=*/GetAllocationSlice(*operand),
   2186       /*destination_buffer=*/GetAllocationSlice(*inst),
   2187       /*mem_size=*/
   2188       llvm_ir::ByteSizeOf(operand->shape(),
   2189                           ir_emitter_context_->llvm_module()->getDataLayout()),
   2190       inst);
   2191 }
   2192 
   2193 std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk(
   2194     const HloInstruction* inst) {
   2195   CHECK_EQ(HloOpcode::kInfeed, inst->opcode());
   2196 
   2197   std::vector<BufferAllocation::Slice> tuple_element_buffers;
   2198   for (int64 i = 0; i < inst->shape().tuple_shapes_size(); ++i) {
   2199     BufferAllocation::Slice buffer = ir_emitter_context_->buffer_assignment()
   2200                                          .GetUniqueSlice(inst, {i})
   2201                                          .ConsumeValueOrDie();
   2202     tuple_element_buffers.push_back(buffer);
   2203   }
   2204 
   2205   return MakeUnique<InfeedThunk>(
   2206       tuple_element_buffers,
   2207       /*destination_buffer=*/GetAllocationSlice(*inst), inst);
   2208 }
   2209 
   2210 std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
   2211     const HloInstruction* inst) {
   2212   if (inst->opcode() == HloOpcode::kDot) {
   2213     const HloInstruction* lhs = inst->operand(0);
   2214     const HloInstruction* rhs = inst->operand(1);
   2215     return MakeUnique<GemmThunk>(
   2216         GetAllocationSlice(*lhs),   // The buffer assigned to LHS.
   2217         GetAllocationSlice(*rhs),   // The buffer assigned to RHS.
   2218         GetAllocationSlice(*inst),  // The output buffer.
   2219         lhs->shape(),               // The shape of LHS.
   2220         rhs->shape(),               // The shape of RHS.
   2221         inst->shape(),              // The shape of the output.
   2222         false,                      // Do not transpose LHS.
   2223         false,                      // Do not transpose RHS.
   2224         inst);
   2225   }
   2226 
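  // Added commentary: the fusion case below is expected to be a "transposed
  // GEMM" fusion, i.e. a fusion whose root is a dot and whose operands are
  // (possibly) rank-2 transposes of the fusion parameters, e.g.
  //   fusion(root = dot(transpose(p0), p1)) called as fusion(%lhs, %rhs).
  // StripTranspose walks through such transposes so the thunk can read the
  // untransposed buffers and fold the transposes into the GEMM call itself.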
   2227   if (inst->opcode() == HloOpcode::kFusion) {
   2228     const HloInstruction* dot = inst->fused_expression_root();
   2229     DCHECK(dot->opcode() == HloOpcode::kDot);
   2230     const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0));
   2231     const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1));
   2232     DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter &&
   2233            rhs_parameter->opcode() == HloOpcode::kParameter);
   2234     const HloInstruction* lhs =
   2235         inst->operand(lhs_parameter->parameter_number());
   2236     const HloInstruction* rhs =
   2237         inst->operand(rhs_parameter->parameter_number());
   2238 
   2239     return MakeUnique<GemmThunk>(
   2240         GetAllocationSlice(*lhs),             // The buffer assigned to LHS.
   2241         GetAllocationSlice(*rhs),             // The buffer assigned to RHS.
   2242         GetAllocationSlice(*inst),            // The output buffer.
   2243         lhs->shape(),                         // The shape of LHS.
   2244         rhs->shape(),                         // The shape of RHS.
   2245         inst->shape(),                        // The shape of the output.
   2246         dot->operand(0)->IsRank2Transpose(),  // Transpose LHS.
   2247         dot->operand(1)->IsRank2Transpose(),  // Transpose RHS.
   2248         inst);
   2249   }
   2250 
   2251   LOG(FATAL) << "Cannot build a GemmThunk for " << inst->ToString();
   2252 }
   2253 
   2254 std::unique_ptr<Thunk> IrEmitterUnnested::BuildFftThunk(
   2255     const HloInstruction* inst) {
   2256   const HloInstruction* operand = inst->operand(0);
   2257   return MakeUnique<FftThunk>(inst->fft_type(), inst->fft_length(),
   2258                               /*input_buffer=*/GetAllocationSlice(*operand),
   2259                               /*output_buffer=*/GetAllocationSlice(*inst),
   2260                               /*input_shape=*/operand->shape(),
   2261                               /*output_shape=*/inst->shape(), inst);
   2262 }
   2263 
   2264 Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo,
   2265                                           KernelThunk* thunk) {
   2266   bool fused = HloOpcode::kFusion == hlo->opcode();
   2267 
   2268   const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo;
   2269   CHECK(inst->opcode() == HloOpcode::kSelectAndScatter ||
   2270         inst->opcode() == HloOpcode::kReduce);
   2271   const HloInstruction* init_value = nullptr;
   2272   switch (inst->opcode()) {
   2273     case HloOpcode::kSelectAndScatter:
   2274       init_value = inst->operand(2);
   2275       break;
   2276     case HloOpcode::kReduce:
   2277       init_value = inst->operand(1);
   2278       break;
   2279     default:
   2280       LOG(FATAL) << "Opcode " << inst->opcode()
   2281                  << " should not need an initializer.";
   2282   }
   2283 
   2284   if (fused && init_value->opcode() == HloOpcode::kParameter) {
   2285     init_value = hlo->operand(init_value->parameter_number());
   2286   }
   2287 
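  // Added note: the initializer kernel is just a broadcast -- every element of
  // hlo's output buffer is filled with the (possibly remapped) init value read
  // above, before the reduce / select-and-scatter kernel itself runs.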
   2288   return EmitTargetElementLoopInThunk(
   2289       *hlo,
   2290       [=](const llvm_ir::IrArray::Index& index) {
   2291         return GetIrArray(*init_value, *hlo)
   2292             .EmitReadArrayElement(index, &ir_builder_);
   2293       },
   2294       thunk);
   2295 }
   2296 
   2297 namespace {
   2298 
   2299 // Checks that the buffers corresponding to the given two HLOs share the same
   2300 // allocation.
   2301 Status CheckHloBuffersShareAllocation(
   2302     const HloInstruction* a, const HloInstruction* b, const ShapeIndex& index,
   2303     const BufferAssignment& buffer_assignment) {
   2304   const BufferAllocation::Slice slice_a =
   2305       buffer_assignment.GetUniqueSlice(a, index).ConsumeValueOrDie();
   2306   const BufferAllocation::Slice slice_b =
   2307       buffer_assignment.GetUniqueSlice(b, index).ConsumeValueOrDie();
   2308   if (slice_a != slice_b) {
   2309     return InternalError(
   2310         "instruction %s %s does not share allocation with instruction %s %s",
   2311         a->ToString().c_str(), slice_a.ToString().c_str(),
   2312         b->ToString().c_str(), slice_b.ToString().c_str());
   2313   }
   2314   return Status::OK();
   2315 }
   2316 
   2317 // Checks that all buffers used during while loop iteration share the same
   2318 // buffer allocation. This includes buffers for while result, while init
   2319 // operand, condition parameter, body parameter and body result.
   2320 // Returns OK on success, error status otherwise.
   2321 Status CheckWhileBuffersShareAllocation(
   2322     const HloInstruction* xla_while,
   2323     const BufferAssignment& buffer_assignment) {
   2324   return ShapeUtil::ForEachSubshapeWithStatus(
   2325       xla_while->shape(),
   2326       [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
   2327         const HloInstruction* condition_parameter =
   2328             xla_while->while_condition()->parameter_instruction(0);
   2329         const HloComputation* body = xla_while->while_body();
   2330         const HloInstruction* body_parameter = body->parameter_instruction(0);
   2331         const HloInstruction* body_result = body->root_instruction();
   2332         TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
   2333             xla_while, xla_while->operand(0), index, buffer_assignment));
   2334         TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
   2335             xla_while, condition_parameter, index, buffer_assignment));
   2336         TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
   2337             xla_while, body_parameter, index, buffer_assignment));
   2338         TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
   2339             xla_while, body_result, index, buffer_assignment));
   2340         return Status::OK();
   2341       });
   2342 }
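// Added rationale: the while thunk is expected to run the condition and body
// thunk sequences in place, with no copies between iterations, so the loop is
// only correct if the init operand, both parameters, the body result and the
// while result all alias the same allocation -- which is what the check above
// verifies.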
   2343 
   2344 // Checks that the buffers used in a conditional instruction are shared with the
   2345 // operands and result as follows:
   2346 //   * The result buffer of the conditional should share the allocation with the
   2347 //     result buffers of the true and false computations.
   2348 //   * The buffer of operand 1 should share the allocation with the buffer of
   2349 //     the parameter 0 instruction of the true computation.
   2350 //   * The buffer of operand 2 should share the allocation with the buffer of
   2351 //     the parameter 0 instruction of the false computation.
   2352 Status CheckConditionalBuffersShareAllocation(
   2353     const HloInstruction* conditional,
   2354     const BufferAssignment& buffer_assignment) {
   2355   TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
   2356       conditional->shape(),
   2357       [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
   2358         TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
   2359             conditional, conditional->true_computation()->root_instruction(),
   2360             index, buffer_assignment));
   2361         TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
   2362             conditional, conditional->false_computation()->root_instruction(),
   2363             index, buffer_assignment));
   2364         return Status::OK();
   2365       }));
   2366   TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
   2367       conditional->operand(1)->shape(),
   2368       [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
   2369         return CheckHloBuffersShareAllocation(
   2370             conditional->operand(1),
   2371             conditional->true_computation()->parameter_instruction(0), index,
   2372             buffer_assignment);
   2373       }));
   2374   TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
   2375       conditional->operand(2)->shape(),
   2376       [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
   2377         return CheckHloBuffersShareAllocation(
   2378             conditional->operand(2),
   2379             conditional->false_computation()->parameter_instruction(0), index,
   2380             buffer_assignment);
   2381       }));
   2382   return Status::OK();
   2383 }
   2384 
   2385 }  // namespace
   2386 
   2387 std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk(
   2388     const HloInstruction* hlo) {
   2389   // Check that all while-related buffers share an allocation.
   2390   TF_CHECK_OK(CheckWhileBuffersShareAllocation(
   2391       hlo, ir_emitter_context_->buffer_assignment()));
   2392 
   2393   // Generate thunk sequence for while 'condition'.
   2394   HloComputation* condition = hlo->while_condition();
   2395   IrEmitterUnnested ir_emitter_condition(hlo_module_config_, condition,
   2396                                          ir_emitter_context_);
   2397   TF_CHECK_OK(condition->root_instruction()->Accept(&ir_emitter_condition));
   2398 
   2399   // Generate thunk sequence for while 'body'.
   2400   HloComputation* body = hlo->while_body();
   2401   IrEmitterUnnested ir_emitter_body(hlo_module_config_, body,
   2402                                     ir_emitter_context_);
   2403   TF_CHECK_OK(body->root_instruction()->Accept(&ir_emitter_body));
   2404 
   2405   return MakeUnique<WhileThunk>(
   2406       GetAllocationSlice(*condition->root_instruction()),  // cond result
   2407       ir_emitter_condition.ConsumeThunkSequence(),
   2408       ir_emitter_body.ConsumeThunkSequence(), hlo);
   2409 }
   2410 
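// Added note: BuildForThunk is the counterpart of BuildWhileThunk for loops
// whose trip count was recovered statically (see while_transformer); the
// caller passes that trip count in as `loop_limit`, so only the body needs to
// be emitted and the condition computation is skipped entirely.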
   2411 std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
   2412     const HloInstruction* hlo, const int64 loop_limit) {
   2413   // Check that all while-related buffers share an allocation.
   2414   TF_CHECK_OK(CheckWhileBuffersShareAllocation(
   2415       hlo, ir_emitter_context_->buffer_assignment()));
   2416 
   2417   // Generate thunk sequence for while 'body' (will be used as the For loop body).
   2418   HloComputation* body = hlo->while_body();
   2419   IrEmitterUnnested ir_emitter_body(hlo_module_config_, body,
   2420                                     ir_emitter_context_);
   2421   TF_CHECK_OK(body->root_instruction()->Accept(&ir_emitter_body));
   2422 
   2423   return MakeUnique<ForThunk>(loop_limit,
   2424                               ir_emitter_body.ConsumeThunkSequence(), hlo);
   2425 }
   2426 
   2427 std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
   2428     const HloInstruction* hlo) {
   2429   // Check that the buffers used in the conditional are shared with the
   2430   // operands and result appropriately.
   2431   TF_CHECK_OK(CheckConditionalBuffersShareAllocation(
   2432       hlo, ir_emitter_context_->buffer_assignment()));
   2433 
   2434   HloComputation* true_computation = hlo->true_computation();
   2435   IrEmitterUnnested ir_emitter_true(hlo_module_config_, true_computation,
   2436                                     ir_emitter_context_);
   2437   TF_CHECK_OK(true_computation->root_instruction()->Accept(&ir_emitter_true));
   2438 
   2439   HloComputation* false_computation = hlo->false_computation();
   2440   IrEmitterUnnested ir_emitter_false(hlo_module_config_, false_computation,
   2441                                      ir_emitter_context_);
   2442   TF_CHECK_OK(false_computation->root_instruction()->Accept(&ir_emitter_false));
   2443 
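  // Added note: at run time the ConditionalThunk reads the predicate from
  // operand 0's buffer and then executes exactly one of the two thunk
  // sequences built above; the buffer-sharing check at the top of this
  // function is what lets each branch read its argument and write its result
  // in place.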
   2444   return MakeUnique<ConditionalThunk>(
   2445       GetAllocationSlice(*hlo->operand(0)),
   2446       GetAllocationSlice(*hlo->operand(1)),
   2447       GetAllocationSlice(*hlo->operand(2)),
   2448       std::move(*ir_emitter_true.ConsumeThunkSequence()),
   2449       std::move(*ir_emitter_false.ConsumeThunkSequence()), hlo);
   2450 }
   2451 
   2452 Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
   2453     const HloInstruction& hlo,
   2454     const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk) {
   2455   VLOG(3) << bindings_.ToString();
   2456 
   2457   const Shape& element_shape = hlo.IsMultiOutputFusion()
   2458                                    ? ShapeUtil::GetSubshape(hlo.shape(), {0})
   2459                                    : hlo.shape();
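  // Added note: for multi-output fusion the launch dimensions are derived from
  // the first tuple element alone, which relies on all outputs of the fusion
  // having the same dimensions.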
   2460   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
   2461       element_shape, ir_emitter_context_->device_description());
   2462   UpdateLaunchDimensions(launch_dimensions, thunk,
   2463                          ir_emitter_context_->llvm_module());
   2464   if (!hlo.IsMultiOutputFusion()) {
   2465     return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo),
   2466                                launch_dimensions, &ir_builder_)
   2467         .EmitLoop(IrName(&hlo));
   2468   }
   2469 
   2470   // For multi-output fusion, emit the loop for each output, then the root tuple.
   2471   std::vector<llvm_ir::IrArray> output_arrays;
   2472   for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) {
   2473     output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
   2474   }
   2475   TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays,
   2476                                          launch_dimensions, &ir_builder_)
   2477                          .EmitLoop(IrName(&hlo)));
   2478 
   2479   std::vector<llvm::Value*> tuple_operand_ptrs;
   2480   for (int64 i = 0; i < output_arrays.size(); ++i) {
   2481     tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
   2482   }
   2483   ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator());
   2484   llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &ir_builder_,
   2485                      module_);
   2486   return Status::OK();
   2487 }
   2488 
   2489 Status IrEmitterUnnested::EmitTargetElementLoop(
   2490     const HloInstruction& hlo,
   2491     const llvm_ir::ElementGenerator& element_generator) {
   2492   CHECK(Thunk::Kind::kKernel == LastThunk()->kind());
   2493   return EmitTargetElementLoopInThunk(hlo, element_generator,
   2494                                       static_cast<KernelThunk*>(LastThunk()));
   2495 }
   2496 
   2497 }  // namespace gpu
   2498 }  // namespace xla
   2499