      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
     17 
     18 #include <stddef.h>
     19 #include <stdint.h>
     20 #include <algorithm>
     21 #include <iterator>
     22 #include <limits>
     23 #include <memory>
     24 #include <utility>
     25 #include <vector>
     26 
     27 #include "tensorflow/core/lib/math/math_util.h"
     28 #include "tensorflow/core/platform/logging.h"
     29 // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
     30 #include "llvm/CodeGen/TargetRegisterInfo.h"
     31 #include "llvm/CodeGen/TargetSubtargetInfo.h"
     32 #include "llvm/IR/BasicBlock.h"
     33 #include "llvm/IR/Constants.h"
     34 #include "llvm/IR/GlobalVariable.h"
     35 #include "llvm/IR/Instructions.h"
     36 #include "llvm/IR/Intrinsics.h"
     37 #include "llvm/IR/LLVMContext.h"
     38 #include "tensorflow/compiler/xla/layout_util.h"
     39 #include "tensorflow/compiler/xla/map_util.h"
     40 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
     41 #include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
     42 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
     43 #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
     44 #include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h"
     45 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
     46 #include "tensorflow/compiler/xla/service/cpu/ir_function.h"
     47 #include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h"
     48 #include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
     49 #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
     50 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
     51 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     52 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
     53 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
     54 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
     55 #include "tensorflow/compiler/xla/service/llvm_ir/ops.h"
     56 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
     57 #include "tensorflow/compiler/xla/shape_util.h"
     58 #include "tensorflow/compiler/xla/status_macros.h"
     59 #include "tensorflow/compiler/xla/types.h"
     60 #include "tensorflow/compiler/xla/util.h"
     61 #include "tensorflow/compiler/xla/window_util.h"
     62 #include "tensorflow/core/lib/core/bits.h"
     63 #include "tensorflow/core/lib/core/errors.h"
     64 #include "tensorflow/core/lib/gtl/array_slice.h"
     65 #include "tensorflow/core/lib/gtl/flatmap.h"
     66 #include "tensorflow/core/lib/gtl/flatset.h"
     67 #include "tensorflow/core/lib/strings/strcat.h"
     68 #include "tensorflow/core/lib/strings/stringprintf.h"
     69 
     70 namespace xla {
     71 
     72 namespace {
     73 using llvm_ir::AsStringRef;
     74 using llvm_ir::IrName;
     75 using llvm_ir::SetToFirstInsertPoint;
     76 namespace gtl = tensorflow::gtl;
     77 }  // namespace
     78 
     79 namespace cpu {
     80 
     81 IrEmitter::IrEmitter(
     82     const HloModule& hlo_module, const BufferAssignment& assignment,
     83     llvm::Module* llvm_module,
     84     std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx,
     85     std::unordered_map<const HloComputation*, int64> computation_to_profile_idx,
     86     llvm::TargetMachine* target_machine,
     87     ExternalConstantPool* external_constant_pool)
     88     : assignment_(assignment),
     89       module_(llvm_module),
     90       arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()),
     91       ir_builder_(llvm_module->getContext()),
     92       instruction_to_profile_idx_(std::move(instruction_to_profile_idx)),
     93       computation_to_profile_idx_(std::move(computation_to_profile_idx)),
     94       alias_analysis_(hlo_module, assignment, &llvm_module->getContext()),
     95       hlo_module_config_(hlo_module.config()),
     96       parallel_cpu_backend_(
     97           options::CpuParallelBackendRequested(hlo_module_config_)),
     98       is_top_level_computation_(false),
     99       target_machine_features_(target_machine),
    100       external_constant_pool_(external_constant_pool) {
    101   ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags(
    102       /*fast_math_enabled=*/hlo_module_config_.debug_options()
    103           .xla_enable_fast_math()));
    104 }
    105 
    106 StatusOr<llvm::Function*> IrEmitter::EmitComputation(
    107     HloComputation* computation, const string& function_name_prefix,
    108     bool is_top_level_computation,
    109     std::vector<const HloInstruction*>* instruction_order) {
    110   string function_name = name_uniquer_.GetUniqueName(function_name_prefix);
    111   VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix
    112           << "]; ordered? " << (instruction_order != nullptr);
    113   is_top_level_computation_ = is_top_level_computation;
    114   num_dynamic_loop_bounds_ = 0;
    115   if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
    116     num_dynamic_loop_bounds_ =
    117         computation->root_instruction()->outer_dimension_partitions().size();
    118   }
    119 
    120   InitializeIrFunction(function_name);
     121   // The rdtscp instruction is x86 specific.  We fall back to LLVM's generic
     122   // readcyclecounter if it is unavailable.
    123   bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 ||
    124                     arch_type_ == llvm::Triple::ArchType::x86_64;
    125   profiling_state_ = ProfilingState(use_rdtscp, GetProfileCountersArgument());
    126   if (instruction_order == nullptr) {
    127     TF_RETURN_IF_ERROR(computation->Accept(this));
    128   } else {
    129     TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, *instruction_order));
    130   }
    131   llvm::Function* ir_function = compute_function_->function();
    132   InsertOrDie(&emitted_functions_, computation, ir_function);
    133   // Delete 'compute_function', finalizing 'ir_function' and restoring caller
    134   // IR insert point.
    135   compute_function_.reset();
    136   return ir_function;
    137 }
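         // Usage note (added commentary, not part of the original source): callers
         // are expected to emit embedded (non-entry) computations before the
         // computations that call them, since handlers such as HandleMap and
         // HandleReduceWindow look up previously emitted functions in
         // emitted_functions_. The entry computation is emitted with
         // is_top_level_computation set to true so that it receives external linkage.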
    138 
    139 void IrEmitter::InitializeIrFunction(const string& function_name) {
     140   // Functions with local linkage get an inlining bonus.  Because we know
     141   // a priori that embedded functions (non-entry functions) will never have
     142   // their names resolved externally, give them local linkage.
    143   llvm::Function::LinkageTypes linkage =
    144       is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage
    145                                 : llvm::GlobalValue::InternalLinkage;
    146   // Create and initialize new IrFunction.
    147   compute_function_.reset(
    148       new IrFunction(function_name, linkage,
    149                      options::OptimizeForSizeRequested(hlo_module_config_),
    150                      hlo_module_config_.debug_options().xla_enable_fast_math(),
    151                      module_, &ir_builder_, num_dynamic_loop_bounds_));
    152 }
    153 
    154 IrEmitter::~IrEmitter() {}
    155 
    156 Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
    157   VLOG(2) << "HandleBitcast: " << bitcast->ToString();
    158   emitted_value_[bitcast] =
    159       ir_builder_.CreateBitCast(GetEmittedValueFor(bitcast->operand(0)),
    160                                 IrShapeType(bitcast->shape())->getPointerTo(),
    161                                 AsStringRef(IrName(bitcast)));
    162   return Status::OK();
    163 }
    164 
    165 Status IrEmitter::HandleConstant(HloInstruction* constant) {
    166   VLOG(2) << "HandleConstant: " << constant->ToString();
    167   const Literal& literal = constant->literal();
    168   llvm::GlobalVariable* global_for_const;
    169 
    170   // We avoid creating large constants in the LLVM IR since LLVM is not
    171   // efficient for large constant arrays.  We still emit "small enough" constant
     172   // arrays into the IR, on the off chance the LLVM optimizer can do something
     173   // interesting with them.
    174   const int kMaxInternalConstantSizeInBytes = 128;
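           // Illustrative sizing example (added commentary, not part of the original
           // source): an f32[4,4] literal occupies 4 * 4 * 4 = 64 bytes, below the
           // 128-byte threshold above, so it is emitted inline as a private LLVM
           // global. An f32[64] literal occupies 256 bytes and, when an external
           // constant pool is available, is registered there instead and referenced
           // through an external global.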
    175   if (external_constant_pool_ &&
    176       ByteSizeOf(literal.shape()) >= kMaxInternalConstantSizeInBytes) {
    177     string global_name = tensorflow::strings::StrCat(
    178         "constant_global_", external_global_constant_counter_++);
    179     global_for_const = new llvm::GlobalVariable(
    180         /*Module=*/*module_,
    181         /*Type=*/IrShapeType(literal.shape()),
    182         /*isConstant=*/true,
    183         /*Linkage=*/llvm::GlobalValue::ExternalLinkage,
    184         /*Initializer=*/nullptr,
    185         /*Name=*/AsStringRef(global_name));
    186     global_for_const->setAlignment(MinimumAlignmentForShape(literal.shape()));
    187     external_constant_pool_->Insert(global_name, literal,
    188                                     MinimumAlignmentForShape(literal.shape()));
    189   } else {
    190     llvm::Constant* initializer =
    191         llvm_ir::ConvertLiteralToIrConstant(literal, module_);
    192     global_for_const = new llvm::GlobalVariable(
    193         /*Module=*/*module_,
    194         /*Type=*/initializer->getType(),
    195         /*isConstant=*/true,
    196         /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
    197         /*Initializer=*/initializer,
    198         /*Name=*/"");
    199     global_for_const->setAlignment(MinimumAlignmentForShape(literal.shape()));
    200   }
    201   emitted_value_[constant] = global_for_const;
    202   VLOG(2) << "  emitted value: " << llvm_ir::DumpToString(*global_for_const);
    203   VLOG(2) << "  its type: "
    204           << llvm_ir::DumpToString(*global_for_const->getType());
    205   return Status::OK();
    206 }
    207 
    208 Status IrEmitter::HandleCopy(HloInstruction* copy) {
    209   if (ShapeUtil::IsTuple(copy->shape())) {
    210     // kCopy shallow copies a tuple so just memcpy the top-level buffer.
    211     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy));
    212     return EmitMemcpy(*(copy->operand(0)), *copy);
    213   } else {
    214     // Use the elemental emitter for non-tuple shapes.
    215     return DefaultAction(copy);
    216   }
    217 }
    218 
    219 // Calculate the alignment of a buffer with a particular size.
    220 int IrEmitter::MinimumAlignmentForBufferSize(int64 buffer_size) {
    221   // GLibc returns a pointer with alignment 8 on 32-bit platforms and 16 on
    222   // 64-bit platforms.  TCMalloc returns a pointer with alignment 8 for
    223   // allocations smaller than kMallocAlignmentThreshold bytes and at least
    224   // alignment 16 for allocations greater than or equal to
    225   // kMallocAlignmentThreshold bytes.  N.B. We could improve on this lower bound
    226   // by explicitly allocating the memory with posix_memalign.  This is
    227   // complicated by our desire to allow parameter buffers created by clients to
    228   // be consumed directly by the JIT.
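           // Worked example (added commentary, not part of the original source): on a
           // 64-bit target the pointer size is 8, so a 256-byte buffer is assumed to
           // be 8-byte aligned, while a 1024-byte buffer (at or above the 512-byte
           // threshold below) is assumed to be 16-byte aligned; an empty buffer
           // reports alignment 1.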
    229   if (buffer_size == 0) {
    230     // No need to align empty buffers.
    231     return 1;
    232   }
    233 
    234   const int64 kMallocAlignmentThreshold = 512;
    235 
    236   int pointer_size = module_->getDataLayout().getPointerSize();
    237   int buffer_alignment = buffer_size >= kMallocAlignmentThreshold
    238                              ? 2 * pointer_size
    239                              : pointer_size;
    240   DCHECK_GT(buffer_alignment, 0);
    241 
    242   return buffer_alignment;
    243 }
    244 
    245 // Calculate the alignment of a buffer allocated for a given primitive type.
    246 int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) {
    247   int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
    248   DCHECK_GE(byte_size, 0);
     249   // The largest scalar is a complex64, so we don't need to worry about
     250   // int64->int truncation here.
    251   DCHECK_LE(byte_size, 8);
    252   return byte_size;
    253 }
    254 
    255 int64 IrEmitter::ByteSizeOf(const Shape& shape) const {
    256   return llvm_ir::ByteSizeOf(shape, module_->getDataLayout());
    257 }
    258 
    259 // Calculate the alignment of a buffer allocated for a given shape.
    260 int IrEmitter::MinimumAlignmentForShape(const Shape& shape) {
    261   if (ShapeUtil::IsScalar(shape)) {
    262     return MinimumAlignmentForPrimitiveType(shape.element_type());
    263   }
    264 
    265   int64 buffer_size = ByteSizeOf(shape);
    266   DCHECK_GE(buffer_size, 0);
    267   DCHECK_LE(buffer_size, SIZE_MAX);
    268 
    269   return MinimumAlignmentForBufferSize(buffer_size);
    270 }
    271 
    272 void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
    273                                                const Shape& shape) {
    274   int alignment = MinimumAlignmentForShape(shape);
    275   if (alignment > 1) {
    276     llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
    277   }
    278 }
    279 
    280 void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
    281                                                int64 buffer_size) {
    282   int alignment = MinimumAlignmentForBufferSize(buffer_size);
    283   if (alignment > 1) {
    284     llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
    285   }
    286 }
    287 
    288 void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
    289                                                      const Shape& shape) {
    290   AttachDereferenceableMetadataForLoad(load, ByteSizeOf(shape));
    291 }
    292 
    293 void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
    294                                                      int64 buffer_size) {
    295   if (buffer_size > 0) {
    296     llvm_ir::SetDereferenceableMetadataForLoad(load, buffer_size);
    297   }
    298 }
    299 
    300 Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
    301   // A tuple is an array of pointers, one for each operand. Each pointer points
    302   // to the output buffer of its corresponding operand. A GetTupleElement
    303   // instruction forwards a pointer to the tuple element buffer at the given
    304   // index.
    305   auto operand = get_tuple_element->operand(0);
    306   const Shape& shape = get_tuple_element->shape();
    307   emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
    308       shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape),
    309       GetEmittedValueFor(operand), &ir_builder_, module_);
    310   return Status::OK();
    311 }
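         // Illustrative layout example (added commentary, not part of the original
         // source): a tuple of shape (f32[4], s32[]) is materialized as an array of
         // two pointers, one to the f32[4] buffer and one to the s32[] buffer. A
         // get-tuple-element with tuple_index 1 loads the second pointer and forwards
         // it as the emitted value; no element data is copied.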
    312 
    313 Status IrEmitter::HandleSelect(HloInstruction* select) {
    314   auto pred = select->operand(0);
    315   auto on_true = select->operand(1);
    316   auto on_false = select->operand(2);
    317   TF_RET_CHECK(pred->shape().element_type() == PRED);
    318 
    319   if (ShapeUtil::IsTuple(select->shape())) {
    320     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(select));
    321     llvm_ir::EmitTupleSelect(
    322         GetIrArrayFor(select), GetIrArrayFor(pred), GetEmittedValueFor(on_true),
    323         GetEmittedValueFor(on_false), &ir_builder_, module_);
    324     return Status::OK();
    325   }
    326 
    327   return DefaultAction(select);
    328 }
    329 
    330 Status IrEmitter::HandleInfeed(HloInstruction* infeed) {
    331   VLOG(2) << "HandleInfeed: " << infeed->ToString();
    332 
    333   const Shape& shape = infeed->shape();
    334 
    335   // The infeed operation produces data (dequeued from the infeed queue) at this
    336   // address, which has been provided by buffer assignment.
    337   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(infeed));
    338   llvm_ir::IrArray infeed_array = GetIrArrayFor(infeed);
    339 
    340   if (ShapeUtil::IsTuple(shape)) {
    341     TF_RET_CHECK(!ShapeUtil::IsNestedTuple(shape));
    342 
    343     // For a tuple, we first copy each of the internal elements to
    344     // their corresponding target locations. We then construct the
    345     // tuple outer buffer containing pointers to the internal
    346     // elements.
    347     std::vector<llvm::Value*> tuple_element_addresses;
    348     for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) {
    349       TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer,
    350                           assignment_.GetUniqueSlice(infeed, {i}));
    351 
    352       const Shape& tuple_element_shape =
    353           ShapeUtil::GetTupleElementShape(shape, i);
    354 
    355       // Only the outer tuple buffer's target address is obtained from
    356       // GetEmittedValueFor, to handle the case when Infeed is the root
    357       // instruction. Target addresses for internal elements can be obtained
    358       // from EmitTempBufferPointer.
    359       llvm::Value* tuple_element_address =
    360           EmitTempBufferPointer(buffer, tuple_element_shape);
    361 
    362       TF_RETURN_IF_ERROR(EmitXfeedTransfer(
    363           XfeedKind::kInfeed, tuple_element_shape, tuple_element_address));
    364 
    365       tuple_element_addresses.push_back(tuple_element_address);
    366     }
    367 
    368     llvm_ir::EmitTuple(infeed_array, tuple_element_addresses, &ir_builder_,
    369                        module_);
    370   } else {
    371     TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kInfeed, shape,
    372                                          GetEmittedValueFor(infeed)));
    373   }
    374 
    375   return Status::OK();
    376 }
    377 
    378 Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
    379                                     llvm::Value* program_buffer_address) {
    380   int64 length = ByteSizeOf(shape);
    381   if (length <= 0 || length > std::numeric_limits<int32>::max()) {
    382     return InvalidArgument(
    383         "xfeed (infeed or outfeed) buffer length %lld is outside the valid "
    384         "size range",
    385         length);
    386   }
    387   int32 length_32 = static_cast<int32>(length);
    388 
    389   int32 shape_length;
    390   TF_ASSIGN_OR_RETURN(llvm::Value * shape_ptr,
    391                       llvm_ir::EncodeSelfDescribingShapeConstant(
    392                           shape, &shape_length, &ir_builder_));
    393 
     394   // The signature of the acquire infeed/outfeed buffer function is:
     395   //
     396   //   (void*)(int32 length, void* shape, int32 shape_length);
    397   llvm::Type* int32_type = ir_builder_.getInt32Ty();
    398   llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
    399   llvm::FunctionType* acquire_type = llvm::FunctionType::get(
    400       i8_ptr_type, {int32_type, i8_ptr_type, int32_type},
    401       /*isVarArg=*/false);
    402 
    403   llvm::Function* acquire_func;
    404   if (kind == XfeedKind::kInfeed) {
    405     acquire_func = llvm::cast<llvm::Function>(module_->getOrInsertFunction(
    406         runtime::kAcquireInfeedBufferForDequeueSymbolName, acquire_type));
    407   } else {
    408     acquire_func = llvm::cast<llvm::Function>(module_->getOrInsertFunction(
    409         runtime::kAcquireOutfeedBufferForPopulationSymbolName, acquire_type));
    410   }
    411   acquire_func->setCallingConv(llvm::CallingConv::C);
    412 
     413   // The signature of the release infeed/outfeed buffer function is:
     414   //
     415   //   (void)(int32 length, void* buffer, void* shape, int32 shape_length);
    416   llvm::FunctionType* release_type = llvm::FunctionType::get(
    417       ir_builder_.getVoidTy(),
    418       {int32_type, i8_ptr_type, i8_ptr_type, int32_type},
    419       /*isVarArg=*/false);
    420 
    421   llvm::Function* release_func;
    422   if (kind == XfeedKind::kInfeed) {
    423     release_func = llvm::cast<llvm::Function>(module_->getOrInsertFunction(
    424         runtime::kReleaseInfeedBufferAfterDequeueSymbolName, release_type));
    425   } else {
    426     release_func = llvm::cast<llvm::Function>(module_->getOrInsertFunction(
    427         runtime::kReleaseOutfeedBufferAfterPopulationSymbolName, release_type));
    428   }
    429   release_func->setCallingConv(llvm::CallingConv::C);
    430 
    431   // Implementation note: this call informs the runtime that it wants a buffer
    432   // of size exactly 'length_32', and the runtime is responsible for
    433   // check-failing the process if there is a mismatch, versus passing us back a
    434   // buffer that we might overrun.
    435   llvm::Value* acquired_pointer = ir_builder_.CreateCall(
    436       acquire_func, {ir_builder_.getInt32(length_32), shape_ptr,
    437                      ir_builder_.getInt32(shape_length)});
    438 
    439   if (kind == XfeedKind::kInfeed) {
    440     // Copy to the program buffer address from the acquired buffer.
    441     ir_builder_.CreateMemCpy(program_buffer_address, acquired_pointer,
    442                              length_32, 1);
    443   } else {
    444     // Outfeed -- copy from the in-program address to the acquired buffer.
    445     ir_builder_.CreateMemCpy(acquired_pointer, program_buffer_address,
    446                              length_32, 1);
    447   }
    448 
    449   ir_builder_.CreateCall(release_func,
    450                          {ir_builder_.getInt32(length_32), acquired_pointer,
    451                           shape_ptr, ir_builder_.getInt32(shape_length)});
    452 
    453   return Status::OK();
    454 }
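         // Illustrative call sequence (added commentary, not part of the original
         // source): for an infeed of shape f32[16], the emitted IR is roughly
         //   buf = AcquireInfeedBufferForDequeue(64, shape_ptr, shape_length)
         //   memcpy(program_buffer_address, buf, 64)
         //   ReleaseInfeedBufferAfterDequeue(64, buf, shape_ptr, shape_length)
         // where 64 is the byte size of the shape; the runtime check-fails on a
         // length mismatch rather than handing back a differently sized buffer.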
    455 
    456 Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
    457   HloInstruction* operand = outfeed->operands()[0];
    458   const Shape& operand_shape = operand->shape();
    459 
    460   llvm::Value* value = GetEmittedValueFor(operand);
    461   if (!ShapeUtil::IsTuple(operand_shape)) {
    462     return EmitXfeedTransfer(XfeedKind::kOutfeed, operand_shape, value);
    463   }
    464 
    465   TF_RET_CHECK(!ShapeUtil::IsNestedTuple(operand_shape));
    466 
    467   for (int64 i = 0; i < operand_shape.tuple_shapes_size(); ++i) {
    468     const Shape& tuple_element_shape =
    469         ShapeUtil::GetTupleElementShape(operand_shape, i);
    470     llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement(
    471         tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape),
    472         value, &ir_builder_, module_);
    473     TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed,
    474                                          tuple_element_shape, tuple_element));
    475   }
    476 
    477   return Status::OK();
    478 }
    479 
    480 Status IrEmitter::HandleSort(HloInstruction* sort) {
    481   // TODO(b/26783907): Implement sort on CPU.
    482   return Unimplemented("Sort is not implemented on CPU.");
    483 }
    484 
    485 Status IrEmitter::HandleTuple(HloInstruction* tuple) {
    486   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple));
    487   std::vector<llvm::Value*> base_ptrs;
    488   for (auto operand : tuple->operands()) {
    489     base_ptrs.push_back(GetEmittedValueFor(operand));
    490   }
    491   llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &ir_builder_, module_);
    492   return Status::OK();
    493 }
    494 
    495 Status IrEmitter::HandleMap(HloInstruction* map) {
    496   gtl::ArraySlice<HloInstruction*> operands(map->operands());
    497   HloComputation* function = map->to_apply();
    498   // The called computation should have been emitted previously.
    499   llvm::Function* mapped_ir_function = FindOrDie(emitted_functions_, function);
    500 
    501   return EmitTargetElementLoop(map, [this, map, operands, mapped_ir_function](
    502                                         const llvm_ir::IrArray::Index& index) {
    503     std::vector<llvm::Value*> parameter_addresses;
    504     for (const HloInstruction* operand : operands) {
    505       const llvm_ir::IrArray& array = GetIrArrayFor(operand);
    506       parameter_addresses.push_back(
    507           array.EmitArrayElementAddress(index, &ir_builder_));
    508     }
    509     return EmitElementFunctionCall(mapped_ir_function, map->shape(),
    510                                    parameter_addresses, "map_function");
    511   });
    512 }
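         // Illustrative semantics (added commentary, not part of the original
         // source): for map(add, {X, Y}) over f32[2,3] operands, the loop emitted
         // above computes output[i, j] = add(X[i, j], Y[i, j]) by passing the
         // addresses of the two operand elements to the previously emitted IR
         // function for `add`.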
    513 
    514 Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
    515   auto operand = reduce_window->operand(0);
    516   const Window& window = reduce_window->window();
    517   HloComputation* function = reduce_window->to_apply();
    518   TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
    519       /*instruction=*/*reduce_window, /*operands=*/{operand},
    520       /*supported_types=*/{F32, BF16}));
    521 
    522   // TODO(b/31410564): Implement dilation for reduce-window.
    523   if (window_util::HasDilation(window)) {
    524     return Unimplemented(
    525         "Dilation for ReduceWindow is not implemented on CPU.");
    526   }
    527 
    528   // The called computation should have been emitted previously.
    529   llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
    530 
    531   // Pseudo code for reduce window:
    532   //
    533   //   for (coordinates O in the output)
    534   //     value = init_value;
    535   //     for (coordinates W in the window)
    536   //       for each index i:
    537   //         input coordinates I_i = O_i * stride_i + W_i - pad_low_i
    538   //       if I within bounds of input:
    539   //         value = function(value, input(I));
    540   //     output(O) = value;
    541   //
    542   // This is completely un-optimized and just here to have something
    543   // that works.
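           // Worked coordinate example (added commentary, not part of the original
           // source): with stride 2, padding_low 1, and window size 3 in some
           // dimension, output coordinate O = 3 and window coordinate W = 0 map to
           // input coordinate I = 3 * 2 + 0 - 1 = 5, while W = 2 maps to I = 7. Any
           // I outside [0, input_bound) lies in the padding and is skipped by the
           // in-bounds check emitted below.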
    544   return EmitTargetElementLoop(
    545       reduce_window, [this, reduce_window, operand, window,
    546                       reducer_function](const llvm_ir::IrArray::Index& index) {
    547         // We fold inputs into the accumulator and initialize it to
     548         // the init_value operand of the reduce_window.
    549         PrimitiveType operand_element_type = operand->shape().element_type();
    550         llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry(
    551             llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
    552             "reduce_window_accumulator_address", &ir_builder_,
    553             MinimumAlignmentForPrimitiveType(operand_element_type));
    554         ir_builder_.CreateStore(ir_builder_.CreateLoad(GetEmittedValueFor(
    555                                     reduce_window->operand(1))),
    556                                 accumulator_address);
    557 
    558         llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"),
    559                                    &ir_builder_);
    560         std::vector<int64> window_size;
    561         for (const auto& dim : window.dimensions()) {
    562           window_size.push_back(dim.size());
    563         }
    564         const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape(
    565             ShapeUtil::MakeShape(operand_element_type, window_size), "window");
    566         CHECK_EQ(window_index.size(), index.size());
    567 
    568         SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
    569 
    570         llvm_ir::IrArray::Index input_index(index.size());
    571         llvm::Value* in_bounds_condition = nullptr;
    572         for (size_t i = 0; i < index.size(); ++i) {
    573           llvm::Value* strided_index = ir_builder_.CreateNSWMul(
    574               index[i], ir_builder_.getInt64(window.dimensions(i).stride()));
    575           input_index[i] = ir_builder_.CreateNSWSub(
    576               ir_builder_.CreateNSWAdd(strided_index, window_index[i]),
    577               ir_builder_.getInt64(window.dimensions(i).padding_low()));
    578 
     579           // We need to check that 0 <= input_index[i] < bound; otherwise
     580           // we are in the padding and can skip the computation. The check
     581           // is equivalent to input_index[i] < bound as an *unsigned*
     582           // comparison, since a negative value wraps around to a large
     583           // positive value.
    584           llvm::Value* index_condition = ir_builder_.CreateICmpULT(
    585               input_index[i], ir_builder_.getInt64(ShapeUtil::GetDimension(
    586                                   operand->shape(), i)));
    587           if (in_bounds_condition == nullptr) {
    588             in_bounds_condition = index_condition;
    589           } else {
    590             in_bounds_condition =
    591                 ir_builder_.CreateAnd(in_bounds_condition, index_condition);
    592           }
    593         }
    594         CHECK(in_bounds_condition != nullptr);
    595 
    596         llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
    597             in_bounds_condition, "in-bounds", &ir_builder_);
    598         SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
    599 
    600         // We are not in the padding, so carry out the computation.
    601         llvm_ir::IrArray input_array(GetIrArrayFor(operand));
    602         llvm::Value* input_value_address =
    603             input_array.EmitArrayElementAddress(input_index, &ir_builder_);
    604         llvm::Value* result = EmitElementFunctionCall(
    605             reducer_function, reduce_window->shape(),
    606             {accumulator_address, input_value_address}, "reducer_function");
    607         ir_builder_.CreateStore(result, accumulator_address);
    608 
    609         SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
    610         return ir_builder_.CreateLoad(accumulator_address);
    611       });
    612 }
    613 
    614 Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
    615   CHECK_EQ(select_and_scatter->operand_count(), 3);
    616   const auto operand = select_and_scatter->operand(0);
    617   const auto source = select_and_scatter->operand(1);
    618   const auto init_value = select_and_scatter->operand(2);
    619   const Window& window = select_and_scatter->window();
    620   PrimitiveType operand_element_type = operand->shape().element_type();
    621   const int64 rank = ShapeUtil::Rank(operand->shape());
    622   CHECK_EQ(rank, ShapeUtil::Rank(source->shape()));
    623   CHECK_EQ(rank, window.dimensions_size());
    624 
    625   // TODO(b/31410564): Implement dilation for select-and-scatter.
    626   if (window_util::HasDilation(window)) {
    627     return Unimplemented(
     628         "Dilation for SelectAndScatter is not implemented on CPU.");
    629   }
    630 
    631   // The select and scatter computations should have been emitted previously.
    632   llvm::Function* select_function =
    633       FindOrDie(emitted_functions_, select_and_scatter->select());
    634   llvm::Function* scatter_function =
    635       FindOrDie(emitted_functions_, select_and_scatter->scatter());
    636 
    637   // Pseudo code for select-and-scatter:
    638   //
    639   // initialized_flag is initially off for every window, and is turned on after
    640   // the first iteration is completed and the first operand value is selected.
    641   //
    642   // output(*) = init_value
    643   // for (coordinates S in the source) {
    644   //   initialized_flag = false
    645   //   for (coordinates W in the window) {
    646   //     I = S * stride + W - pad_low
    647   //     if I within bounds of operand:
    648   //       if !initialized_flag or select(selected_value, operand(I)) == false:
    649   //         selected_value = operand(I)
    650   //         selected_index = I
    651   //         initialized_flag = true
    652   //   }
    653   //   output(selected_index) = scatter(output(selected_index), source(S))
    654   // }
    655   //
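           // Illustrative use (added commentary, not part of the original source):
           // with a greater-or-equal `select` computation and an addition `scatter`
           // computation over 2x2 windows with stride 2, each source element is
           // added into the output at the position of the maximum element of its
           // window; this is the scatter step of the gradient of max-pooling.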
    656 
    657   // Initialize the output array with the given init_value.
    658   TF_RETURN_IF_ERROR(EmitTargetElementLoop(
    659       select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"),
    660       [this, init_value](const llvm_ir::IrArray::Index& target_index) {
    661         llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
    662         return ir_builder_.CreateLoad(init_value_addr);
    663       }));
    664 
    665   // Create a loop to iterate over the source array to scatter to the output.
    666   llvm_ir::ForLoopNest source_loops(IrName(select_and_scatter), &ir_builder_);
    667   const llvm_ir::IrArray::Index source_index =
    668       source_loops.AddLoopsForShape(source->shape(), "source");
    669   SetToFirstInsertPoint(source_loops.GetInnerLoopBodyBasicBlock(),
    670                         &ir_builder_);
    671 
    672   // Allocate space to keep the currently selected value, its index, and
    673   // the boolean initialized_flag, which is initially set to false.
    674   llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
    675       llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
    676       "selected_value_address", &ir_builder_,
    677       MinimumAlignmentForPrimitiveType(operand_element_type));
    678   llvm::Value* selected_index_address =
    679       llvm_ir::EmitAllocaAtFunctionEntryWithCount(
    680           ir_builder_.getInt64Ty(), ir_builder_.getInt32(rank),
    681           "selected_index_address", &ir_builder_);
    682   llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
    683       ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_);
    684   ir_builder_.CreateStore(ir_builder_.getInt1(false), initialized_flag_address);
    685 
    686   // Create the inner loop to iterate over the window.
    687   llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"),
    688                                     &ir_builder_);
    689   std::vector<int64> window_size;
    690   for (const auto& dim : window.dimensions()) {
    691     window_size.push_back(dim.size());
    692   }
    693   const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape(
    694       ShapeUtil::MakeShape(operand_element_type, window_size), "window");
    695   SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
    696                         &ir_builder_);
    697 
     698   // Compute the operand index to visit and check whether it is within the
     699   // bounds of the operand. The unsigned comparison implicitly also checks
     700   // whether the operand index is >= 0.
    701   llvm_ir::IrArray::Index operand_index(source_index.size());
    702   llvm::Value* in_bounds_condition = ir_builder_.getTrue();
    703   for (int64 i = 0; i < rank; ++i) {
    704     llvm::Value* strided_index = ir_builder_.CreateNSWMul(
    705         source_index[i], ir_builder_.getInt64(window.dimensions(i).stride()));
    706     operand_index[i] = ir_builder_.CreateNSWSub(
    707         ir_builder_.CreateNSWAdd(strided_index, window_index[i]),
    708         ir_builder_.getInt64(window.dimensions(i).padding_low()));
    709     llvm::Value* index_condition = ir_builder_.CreateICmpULT(
    710         operand_index[i],
    711         ir_builder_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
    712     in_bounds_condition =
    713         ir_builder_.CreateAnd(in_bounds_condition, index_condition);
    714   }
    715   CHECK(in_bounds_condition != nullptr);
    716 
    717   // Only need to do something if the operand index is within the bounds. First
    718   // check if the initialized_flag is set.
    719   llvm_ir::LlvmIfData if_in_bounds =
    720       llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &ir_builder_);
    721   SetToFirstInsertPoint(if_in_bounds.true_block, &ir_builder_);
    722   llvm_ir::LlvmIfData if_initialized =
    723       llvm_ir::EmitIfThenElse(ir_builder_.CreateLoad(initialized_flag_address),
    724                               "initialized", &ir_builder_);
    725 
    726   // If the initialized_flag is false, initialize the selected value and index
     727   // with the current operand element.
    728   SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_);
    729   const auto save_operand_index =
    730       [&](const llvm_ir::IrArray::Index& operand_index) {
    731         for (int64 i = 0; i < rank; ++i) {
    732           llvm::Value* selected_index_address_slot =
    733               ir_builder_.CreateInBoundsGEP(selected_index_address,
    734                                             {ir_builder_.getInt32(i)});
    735           ir_builder_.CreateStore(operand_index[i],
    736                                   selected_index_address_slot);
    737         }
    738       };
    739   llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
    740   llvm::Value* operand_data =
    741       operand_array.EmitReadArrayElement(operand_index, &ir_builder_);
    742   ir_builder_.CreateStore(operand_data, selected_value_address);
    743   save_operand_index(operand_index);
    744   ir_builder_.CreateStore(ir_builder_.getInt1(true), initialized_flag_address);
    745 
    746   // If the initialized_flag is true, call the `select` function to potentially
     747   // update the selected value and index with the current operand element.
    748   SetToFirstInsertPoint(if_initialized.true_block, &ir_builder_);
    749   const Shape output_shape = ShapeUtil::MakeShape(PRED, {});
    750   llvm::Value* operand_address =
    751       operand_array.EmitArrayElementAddress(operand_index, &ir_builder_);
    752   llvm::Value* result = EmitElementFunctionCall(
    753       select_function, output_shape, {selected_value_address, operand_address},
    754       "select_function");
    755 
    756   // If the 'select' function returns false, update the selected value and the
     757   // index to the current operand element.
    758   llvm::Value* cond = ir_builder_.CreateICmpNE(
    759       result,
    760       llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
    761       "boolean_predicate");
    762   llvm_ir::LlvmIfData if_select_lhs =
    763       llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_);
    764   SetToFirstInsertPoint(if_select_lhs.false_block, &ir_builder_);
    765   ir_builder_.CreateStore(ir_builder_.CreateLoad(operand_address),
    766                           selected_value_address);
    767   save_operand_index(operand_index);
    768 
    769   // After iterating over the window elements, scatter the source element to
    770   // the selected index of the output. The value we store at the output
    771   // location is computed by calling the `scatter` function with the source
    772   // value and the current output value.
    773   SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
    774                         &ir_builder_);
    775   llvm_ir::IrArray::Index selected_index;
    776   for (int64 i = 0; i < rank; ++i) {
    777     llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP(
    778         selected_index_address, {ir_builder_.getInt32(i)});
    779     selected_index.push_back(
    780         ir_builder_.CreateLoad(selected_index_address_slot));
    781   }
    782   llvm_ir::IrArray source_array(GetIrArrayFor(source));
    783   llvm::Value* source_value_address =
    784       source_array.EmitArrayElementAddress(source_index, &ir_builder_);
    785   llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter));
    786   llvm::Value* output_value_address =
    787       output_array.EmitArrayElementAddress(selected_index, &ir_builder_);
    788   llvm::Value* scatter_value = EmitElementFunctionCall(
    789       scatter_function, source->shape(),
    790       {output_value_address, source_value_address}, "scatter_function");
    791   output_array.EmitWriteArrayElement(selected_index, scatter_value,
    792                                      &ir_builder_);
    793 
    794   SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(),
    795                         &ir_builder_);
    796   return Status::OK();
    797 }
    798 
    799 Status IrEmitter::HandleDot(HloInstruction* dot) {
    800   auto lhs = dot->operand(0);
    801   auto rhs = dot->operand(1);
    802   TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
    803       /*instruction=*/*dot, /*operands=*/{lhs, rhs},
    804       /*supported_types=*/{F16, F32, F64, C64}));
    805   const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
    806   if (dnums.lhs_batch_dimensions_size() > 0 ||
    807       dnums.rhs_batch_dimensions_size() > 0) {
    808     return Unimplemented("Dot with batch dimensions not implemented.");
    809   }
    810 
    811   if (dnums.lhs_contracting_dimensions_size() != 1) {
    812     // This is disallowed by ShapeInference today.
    813     return Unimplemented(
    814         "Dot with multiple contracting dimensions not implemented.");
    815   }
    816 
    817   if (dnums.lhs_contracting_dimensions(0) !=
    818           std::min(lhs->shape().dimensions_size() - 1, 1) ||
    819       dnums.rhs_contracting_dimensions(0) != 0) {
    820     return Unimplemented(
    821         "Dot with non-standard contracting dimensions not implemented.");
    822   }
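           // Illustrative example of the check above (added commentary, not part of
           // the original source): for a matrix-matrix product f32[m,k] x f32[k,n],
           // the lhs contracting dimension is 1 and the rhs contracting dimension is
           // 0, which is accepted; for a vector-matrix product f32[k] x f32[k,n],
           // the expected lhs contracting dimension is min(1 - 1, 1) = 0.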
    823 
    824   llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs));
    825   llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs));
    826 
    827   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dot));
    828   llvm_ir::IrArray target_array = GetIrArrayFor(dot);
    829 
    830   VLOG(2) << "HandleDot: ";
    831   VLOG(2) << "  lhs operand: "
    832           << llvm_ir::DumpToString(*lhs_array.GetBasePointer());
    833   VLOG(2) << "  rhs operand: "
    834           << llvm_ir::DumpToString(*rhs_array.GetBasePointer());
    835   VLOG(2) << "  target: "
    836           << llvm_ir::DumpToString(*target_array.GetBasePointer());
    837 
     838   // The dot operation is complicated, so we delegate to a helper class.
    839   return DotOpEmitter::EmitDotOperation(
    840       *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array,
    841       lhs_array, rhs_array, /*addend_array=*/nullptr,
    842       GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_,
    843       target_machine_features_);
    844 }
    845 
    846 Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
    847   auto lhs = convolution->operand(0);
    848   auto rhs = convolution->operand(1);
    849   const auto& window = convolution->window();
    850   TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
    851       /*instruction=*/*convolution, /*operands=*/{lhs, rhs},
    852       /*supported_types=*/{F16, F32, C64}));
    853 
    854   const ConvolutionDimensionNumbers& dnums =
    855       convolution->convolution_dimension_numbers();
    856 
    857   if (PotentiallyImplementedAsEigenConvolution(*convolution)) {
    858     const Shape& lhs_shape = lhs->shape();
    859     const Shape& rhs_shape = rhs->shape();
    860     const Shape& convolution_shape = convolution->shape();
     861     // The input, kernel, and output must all use a monotonic dim-0-major layout.
    862     if (LayoutUtil::IsMonotonicWithDim0Major(lhs_shape.layout()) &&
    863         LayoutUtil::IsMonotonicWithDim0Major(rhs_shape.layout()) &&
    864         LayoutUtil::IsMonotonicWithDim0Major(convolution_shape.layout())) {
    865       // We lower 1D convolutions into calls to the same Eigen function as 2D
    866       // convolutions, except that we pretend that the 1D convolution is really
    867       // a 2D convolution with the missing dimension set to 1.  We also adjust
     868       // the padding and dilation parameters as needed.
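                 // For example (added commentary, not part of the original source;
                 // shapes assume batch-major input and spatial-major kernel dimension
                 // numbers): a 1D convolution with input f32[8,100,3] and kernel
                 // f32[5,3,16] is emitted as a 2D Eigen convolution with input_rows =
                 // 100, input_cols = 1, kernel_rows = 5, kernel_cols = 1, and neutral
                 // column stride (1), padding (0), and dilation (1) parameters.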
    869       bool one_dim_convolution = lhs_shape.dimensions_size() == 3;
    870       llvm::Value* lhs_address = GetEmittedValueFor(lhs);
    871       llvm::Value* rhs_address = GetEmittedValueFor(rhs);
    872       TF_RETURN_IF_ERROR(EmitTargetAddressForOp(convolution));
    873 
    874       const ConvolutionDimensionNumbers& dnums =
    875           convolution->convolution_dimension_numbers();
    876 
    877       // Input tensor.
    878       const Shape& input_shape = convolution->operand(0)->shape();
    879       int64 input_batch = input_shape.dimensions(dnums.input_batch_dimension());
    880       int64 input_rows =
    881           input_shape.dimensions(dnums.input_spatial_dimensions(0));
    882       int64 input_cols =
    883           one_dim_convolution
    884               ? 1
    885               : input_shape.dimensions(dnums.input_spatial_dimensions(1));
    886       int64 input_channels =
    887           input_shape.dimensions(dnums.input_feature_dimension());
    888 
    889       // Kernel tensor.
    890       const Shape& kernel_shape = convolution->operand(1)->shape();
    891       int64 kernel_rows =
    892           kernel_shape.dimensions(dnums.kernel_spatial_dimensions(0));
    893       int64 kernel_cols =
    894           one_dim_convolution
    895               ? 1
    896               : kernel_shape.dimensions(dnums.kernel_spatial_dimensions(1));
    897       int64 kernel_channels =
    898           kernel_shape.dimensions(dnums.kernel_input_feature_dimension());
    899       int64 kernel_filters =
    900           kernel_shape.dimensions(dnums.kernel_output_feature_dimension());
    901 
    902       // Output tensor.
    903       const Shape& convolution_shape = convolution->shape();
    904       int64 output_rows =
    905           convolution_shape.dimensions(dnums.output_spatial_dimensions(0));
    906       int64 output_cols = one_dim_convolution
    907                               ? 1
    908                               : convolution_shape.dimensions(
    909                                     dnums.output_spatial_dimensions(1));
    910 
    911       // Extract the window stride for the convolution.
    912       const Window& window = convolution->window();
    913       int64 row_stride = window.dimensions(0).stride();
    914       int64 col_stride =
    915           one_dim_convolution ? 1 : window.dimensions(1).stride();
    916 
    917       int64 padding_top = window.dimensions(0).padding_low();
    918       int64 padding_bottom = window.dimensions(0).padding_high();
    919       int64 padding_left =
    920           one_dim_convolution ? 0 : window.dimensions(1).padding_low();
    921       int64 padding_right =
    922           one_dim_convolution ? 0 : window.dimensions(1).padding_high();
    923 
    924       int64 lhs_row_dilation = window.dimensions(0).base_dilation();
    925       int64 lhs_col_dilation =
    926           one_dim_convolution ? 1 : window.dimensions(1).base_dilation();
    927       int64 rhs_row_dilation = window.dimensions(0).window_dilation();
    928       int64 rhs_col_dilation =
    929           one_dim_convolution ? 1 : window.dimensions(1).window_dilation();
    930 
    931       PrimitiveType primitive_type = lhs->shape().element_type();
    932       llvm::Type* ir_ptr_type = primitive_type == F16
    933                                     ? ir_builder_.getHalfTy()->getPointerTo()
    934                                     : ir_builder_.getFloatTy()->getPointerTo();
    935       llvm::Type* int64_type = ir_builder_.getInt64Ty();
    936       llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo();
    937       llvm::FunctionType* conv_type = llvm::FunctionType::get(
    938           ir_builder_.getVoidTy(),
    939           {int8_ptr_type, ir_ptr_type, ir_ptr_type, ir_ptr_type, int64_type,
    940            int64_type,    int64_type,  int64_type,  int64_type,  int64_type,
    941            int64_type,    int64_type,  int64_type,  int64_type,  int64_type,
    942            int64_type,    int64_type,  int64_type,  int64_type,  int64_type,
    943            int64_type,    int64_type,  int64_type,  int64_type},
    944           /*isVarArg=*/false);
    945       bool multi_threaded_eigen =
    946           hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
    947       const char* fn_name =
    948           primitive_type == F16
    949               ? (multi_threaded_eigen
    950                      ? runtime::kEigenConvF16SymbolName
    951                      : runtime::kEigenSingleThreadedConvF16SymbolName)
    952               : (multi_threaded_eigen
    953                      ? runtime::kEigenConvF32SymbolName
    954                      : runtime::kEigenSingleThreadedConvF32SymbolName);
    955       llvm::Function* conv_func = llvm::cast<llvm::Function>(
    956           module_->getOrInsertFunction(fn_name, conv_type));
    957       conv_func->setCallingConv(llvm::CallingConv::C);
    958       conv_func->setDoesNotThrow();
    959       conv_func->setOnlyAccessesArgMemory();
    960       ir_builder_.CreateCall(
    961           conv_func, {
    962                          GetExecutableRunOptionsArgument(),
    963                          ir_builder_.CreateBitCast(
    964                              GetEmittedValueFor(convolution), ir_ptr_type),
    965                          ir_builder_.CreateBitCast(lhs_address, ir_ptr_type),
    966                          ir_builder_.CreateBitCast(rhs_address, ir_ptr_type),
    967                          ir_builder_.getInt64(input_batch),
    968                          ir_builder_.getInt64(input_rows),
    969                          ir_builder_.getInt64(input_cols),
    970                          ir_builder_.getInt64(input_channels),
    971                          ir_builder_.getInt64(kernel_rows),
    972                          ir_builder_.getInt64(kernel_cols),
    973                          ir_builder_.getInt64(kernel_channels),
    974                          ir_builder_.getInt64(kernel_filters),
    975                          ir_builder_.getInt64(output_rows),
    976                          ir_builder_.getInt64(output_cols),
    977                          ir_builder_.getInt64(row_stride),
    978                          ir_builder_.getInt64(col_stride),
    979                          ir_builder_.getInt64(padding_top),
    980                          ir_builder_.getInt64(padding_bottom),
    981                          ir_builder_.getInt64(padding_left),
    982                          ir_builder_.getInt64(padding_right),
    983                          ir_builder_.getInt64(lhs_row_dilation),
    984                          ir_builder_.getInt64(lhs_col_dilation),
    985                          ir_builder_.getInt64(rhs_row_dilation),
    986                          ir_builder_.getInt64(rhs_col_dilation),
    987                      });
    988 
    989       return Status::OK();
    990     }
    991   }
    992 
    993   // This is a completely un-optimized version of convolution just to
    994   // have an early version that works. E.g. the input index and
     995   // padding calculations are not hoisted out of the inner loop.
    996   //
    997   // See the description of convolution in the XLA documentation for the pseudo
    998   // code for convolution.
    999   return EmitTargetElementLoop(
   1000       convolution, [this, convolution, lhs, rhs, window,
   1001                     dnums](const llvm_ir::IrArray::Index& index) {
   1002         int num_spatial_dims = dnums.output_spatial_dimensions_size();
   1003         std::vector<llvm::Value*> output_spatial(num_spatial_dims);
   1004         for (int i = 0; i < num_spatial_dims; ++i) {
   1005           output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
   1006         }
   1007         llvm::Value* output_feature = index[dnums.output_feature_dimension()];
   1008         llvm::Value* batch = index[dnums.output_batch_dimension()];
   1009 
   1010         // We will accumulate the products into this sum to calculate
   1011         // the output entry at the given index.
   1012         PrimitiveType lhs_element_type = lhs->shape().element_type();
   1013         llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
   1014             llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_),
   1015             "convolution_sum_address", &ir_builder_,
   1016             MinimumAlignmentForPrimitiveType(lhs_element_type));
   1017         ir_builder_.CreateStore(
   1018             llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0), sum_address);
   1019 
   1020         llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &ir_builder_);
   1021         std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
   1022         for (int i = 0; i < num_spatial_dims; ++i) {
   1023           kernel_spatial[i] =
   1024               loops
   1025                   .AddLoop(0,
   1026                            rhs->shape().dimensions(
   1027                                dnums.kernel_spatial_dimensions(i)),
   1028                            tensorflow::strings::StrCat("k", i))
   1029                   ->GetIndVarValue();
   1030         }
   1031         llvm::Value* input_feature =
   1032             loops
   1033                 .AddLoop(
   1034                     0, lhs->shape().dimensions(dnums.input_feature_dimension()),
   1035                     "iz")
   1036                 ->GetIndVarValue();
   1037 
   1038         SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
   1039 
   1040         // Calculate the spatial index in the input array, taking striding,
   1041         // dilation and padding into account. An index in the padding will be
   1042         // out of the bounds of the array.
   1043         const auto calculate_input_index =
   1044             [this](llvm::Value* output_index, llvm::Value* kernel_index,
   1045                    const WindowDimension& window_dim) {
   1046               llvm::Value* strided_index = ir_builder_.CreateNSWMul(
   1047                   output_index, ir_builder_.getInt64(window_dim.stride()));
   1048               llvm::Value* dilated_kernel_index = ir_builder_.CreateNSWMul(
   1049                   kernel_index,
   1050                   ir_builder_.getInt64(window_dim.window_dilation()));
   1051               return ir_builder_.CreateNSWSub(
   1052                   ir_builder_.CreateNSWAdd(strided_index, dilated_kernel_index),
   1053                   ir_builder_.getInt64(window_dim.padding_low()));
   1054             };
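                 // Worked example for calculate_input_index (added commentary, not
                 // part of the original source): with stride 2, window_dilation 3,
                 // and padding_low 1, output index 4 and kernel index 2 map to the
                 // dilated input position 4 * 2 + 2 * 3 - 1 = 13.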
   1055         std::vector<llvm::Value*> input_spatial(num_spatial_dims);
   1056         for (int i = 0; i < num_spatial_dims; ++i) {
   1057           input_spatial[i] = calculate_input_index(
   1058               output_spatial[i], kernel_spatial[i], window.dimensions(i));
   1059         }
   1060 
    1061         // We need to check that 0 <= input dim < bound; otherwise we are in
    1062         // the padding and can skip the computation. The check is equivalent
    1063         // to input dim < bound as an *unsigned* comparison, since a negative
    1064         // value wraps around to a large positive value. The input dim is
    1065         // dilated, so we need to dilate the bound as well to match.
   1066 
   1067         // Also need to check that the input coordinates are not in one of the
   1068         // holes created by base dilation.
   1069         const auto not_in_hole = [&](llvm::Value* input_index,
   1070                                      int64 base_dilation) {
   1071           llvm::Value* remainder = ir_builder_.CreateSRem(
   1072               input_index, ir_builder_.getInt64(base_dilation));
   1073           return ir_builder_.CreateICmpEQ(remainder, ir_builder_.getInt64(0));
   1074         };
   1075 
   1076         llvm::Value* in_bounds_condition = ir_builder_.getInt1(true);
   1077         for (int i = 0; i < num_spatial_dims; ++i) {
   1078           llvm::ConstantInt* input_bound =
   1079               ir_builder_.getInt64(window_util::DilatedBound(
   1080                   lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
   1081                   window.dimensions(i).base_dilation()));
   1082           llvm::Value* dim_in_bound =
   1083               ir_builder_.CreateICmpULT(input_spatial[i], input_bound);
   1084           llvm::Value* dim_not_in_hole = not_in_hole(
   1085               input_spatial[i], window.dimensions(i).base_dilation());
   1086           llvm::Value* dim_ok =
   1087               ir_builder_.CreateAnd(dim_in_bound, dim_not_in_hole);
   1088           in_bounds_condition =
   1089               ir_builder_.CreateAnd(in_bounds_condition, dim_ok);
   1090         }
   1091 
   1092         // Now we need to map the dilated base coordinates back to the actual
   1093         // data indices on the lhs.
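                 // That is, lhs_index = input_index / base_dilation; the result is only
                 // meaningful for indices that passed the not-in-hole check folded into
                 // in_bounds_condition above.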
   1094         const auto undilate = [&](llvm::Value* input_index,
   1095                                   int64 base_dilation) {
   1096           return ir_builder_.CreateSDiv(input_index,
   1097                                         ir_builder_.getInt64(base_dilation));
   1098         };
   1099         for (int i = 0; i < num_spatial_dims; ++i) {
   1100           input_spatial[i] =
   1101               undilate(input_spatial[i], window.dimensions(i).base_dilation());
   1102         }
   1103 
   1104         llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
   1105             in_bounds_condition, "in-bounds", &ir_builder_);
   1106         SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
   1107 
   1108         // We are not in the padding, so carry out the computation.
   1109         int num_dims = num_spatial_dims + 2;
   1110         llvm_ir::IrArray::Index input_index(num_dims);
   1111         for (int i = 0; i < num_spatial_dims; ++i) {
   1112           input_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
   1113         }
   1114         input_index[dnums.input_feature_dimension()] = input_feature;
   1115         input_index[dnums.input_batch_dimension()] = batch;
   1116 
   1117         llvm_ir::IrArray kernel_array(GetIrArrayFor(rhs));
   1118         llvm_ir::IrArray::Index kernel_index(num_dims);
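                 // If the window is reversed along a spatial dimension, mirror the
                 // kernel index within that dimension (size - 1 - k); otherwise use it
                 // as is.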
   1119         for (int i = 0; i < num_spatial_dims; ++i) {
   1120           kernel_index[dnums.kernel_spatial_dimensions(i)] =
   1121               window.dimensions(i).window_reversal()
   1122                   ? ir_builder_.CreateNSWSub(
   1123                         ir_builder_.getInt64(window.dimensions(i).size() - 1),
   1124                         kernel_spatial[i])
   1125                   : kernel_spatial[i];
   1126         }
   1127 
   1128         kernel_index[dnums.kernel_input_feature_dimension()] = input_feature;
   1129         kernel_index[dnums.kernel_output_feature_dimension()] = output_feature;
   1130 
   1131         llvm_ir::IrArray input_array(GetIrArrayFor(lhs));
   1132         llvm::Value* product = ir_builder_.CreateFMul(
   1133             input_array.EmitReadArrayElement(input_index, &ir_builder_),
   1134             kernel_array.EmitReadArrayElement(kernel_index, &ir_builder_));
   1135         llvm::Value* sum = ir_builder_.CreateFAdd(
   1136             ir_builder_.CreateLoad(sum_address), product);
   1137         ir_builder_.CreateStore(sum, sum_address);
   1138 
   1139         SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
   1140         return ir_builder_.CreateLoad(sum_address);
   1141       });
   1142 }
   1143 
   1144 Status IrEmitter::HandleFft(HloInstruction* fft) {
   1145   auto operand = fft->operand(0);
   1146   TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
   1147       /*instruction=*/*fft, /*operands=*/{operand},
   1148       /*supported_types=*/{F32, C64}));
   1149   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout()));
   1150   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout()));
   1151   VLOG(3) << "operand=" << ShapeUtil::HumanStringWithLayout(operand->shape());
   1152   VLOG(3) << "fft=" << ShapeUtil::HumanStringWithLayout(fft->shape());
   1153 
   1154   llvm::Value* operand_address = GetEmittedValueFor(operand);
   1155   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fft));
   1156 
   1157   const std::vector<int64>& fft_length = fft->fft_length();
   1158   int64 input_batch = 1;
   1159   for (int i = 0; i < fft->shape().dimensions_size() - fft_length.size(); i++) {
   1160     input_batch *= fft->shape().dimensions(i);
   1161   }
   1162 
   1163   // Args have been computed, make the call.
   1164   llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo();
   1165   llvm::Type* int32_type = ir_builder_.getInt32Ty();
   1166   llvm::Type* int64_type = ir_builder_.getInt64Ty();
   1167   llvm::FunctionType* fft_type = llvm::FunctionType::get(
   1168       ir_builder_.getVoidTy(),
   1169       {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type,
   1170        int64_type, int64_type, int64_type, int64_type},
   1171       /*isVarArg=*/false);
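           // The parameters above correspond, in order, to: run_options, output buffer,
           // input buffer, fft_type, fft_rank, input_batch, and fft_length[0..2]
           // (lengths beyond fft_rank are passed as 0); see the call emitted below.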
   1172   const char* fn_name = runtime::kEigenFftSymbolName;
   1173   llvm::Function* fft_func = llvm::cast<llvm::Function>(
   1174       module_->getOrInsertFunction(fn_name, fft_type));
   1175   fft_func->setCallingConv(llvm::CallingConv::C);
   1176   fft_func->setDoesNotThrow();
   1177   fft_func->setOnlyAccessesInaccessibleMemOrArgMem();
   1178   const int fft_rank = fft_length.size();
   1179   ir_builder_.CreateCall(
   1180       fft_func,
   1181       {GetExecutableRunOptionsArgument(),
   1182        ir_builder_.CreateBitCast(GetEmittedValueFor(fft), int8_ptr_type),
   1183        ir_builder_.CreateBitCast(operand_address, int8_ptr_type),
   1184        ir_builder_.getInt32(fft->fft_type()), ir_builder_.getInt32(fft_rank),
   1185        ir_builder_.getInt64(input_batch),
   1186        ir_builder_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
   1187        ir_builder_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
   1188        ir_builder_.getInt64(fft_rank > 2 ? fft_length[2] : 0)});
   1189 
   1190   return Status::OK();
   1191 }
   1192 
   1193 Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) {
   1194   if (hlo_module_config_.replica_count() == 1) {
   1195     // When there is a single replica, a cross replica sum is the identity
   1196     // function, and the buffer assignment expects a copy (we could eliminate
   1197     // these at the HLO level as an optimization).
   1198     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs));
   1199     return EmitMemcpy(*crs->operand(0), *crs);
   1200   }
   1201 
   1202   // TODO(b/33011107): Support cross replica sum on CPU.
   1203   return Unimplemented("CrossReplicaSum is not implemented on CPU.");
   1204 }
   1205 
   1206 // Fills up the free variables in 'index_with_free_var' with values from
    1207 // 'filler_index'. The number of free variables must be the same as the
    1208 // size of 'filler_index'.
    1209 //
    1210 // This is often used after dimension reduction, where
    1211 // 'index_with_free_var' has one or more dimensions reduced; the reduced
    1212 // dimensions serve as free variables (represented as nullptr). For example,
    1213 // if we have a 4-dimensional input and the dimension being reduced is
    1214 // dimension 2 (the third dimension), the index after the reduction will
    1215 // look like [i, j, NULL, k].
   1216 //
   1217 // Here we fill up that free variable by 'filler_index', which contains
   1218 // the value in the reduced dimension.
   1219 static llvm_ir::IrArray::Index FillReducedDimensionIndex(
   1220     llvm_ir::IrArray::Index index_with_free_var,
   1221     llvm_ir::IrArray::Index filler_index) {
   1222   llvm_ir::IrArray::Index::const_iterator it = filler_index.begin();
   1223 
   1224   for (size_t i = 0; i < index_with_free_var.size(); ++i) {
   1225     if (index_with_free_var[i] == nullptr) {
   1226       index_with_free_var[i] = *it++;
   1227     }
   1228   }
   1229   CHECK(filler_index.end() == it);
   1230   return index_with_free_var;
   1231 }
   1232 
   1233 Status IrEmitter::HandleParameter(HloInstruction* parameter) {
   1234   VLOG(2) << "HandleParameter: " << parameter->ToString();
   1235   auto param_number = parameter->parameter_number();
   1236   auto param_shape = parameter->shape();
   1237 
   1238   // We have to access the parameter at offset param_number in the params
   1239   // array. The code generated here is equivalent to this C code:
   1240   //
   1241   //   i8* param_address_untyped = params[param_number];
   1242   //   Param* param_address_typed = (Param*)param_address_untyped;
   1243   //
   1244   // Where Param is the actual element type of the underlying buffer (for
   1245   // example, float for an XLA F32 element type).
   1246   llvm::Value* params = compute_function_->parameters_arg();
   1247   llvm::Value* param_address_offset =
   1248       llvm_ir::EmitBufferIndexingGEP(params, param_number, &ir_builder_);
   1249   llvm::LoadInst* param_address_untyped =
   1250       ir_builder_.CreateLoad(param_address_offset);
   1251   param_address_untyped->setName(AsStringRef(IrName(parameter, "untyped")));
   1252   if (is_top_level_computation_ &&
   1253       hlo_module_config_.debug_options()
   1254           .xla_llvm_enable_invariant_load_metadata()) {
   1255     // In the entry computation the parameter slots in the %params argument are
   1256     // invariant through program execution.  In computations that are called
   1257     // from the entry computation (via kWhile, kCall and kConditional) the
   1258     // parameter slots are *not* invariant since they're written to by their
   1259     // callers.
   1260     param_address_untyped->setMetadata(
   1261         llvm::LLVMContext::MD_invariant_load,
   1262         llvm::MDNode::get(param_address_untyped->getContext(), /*MDs=*/{}));
   1263   }
   1264 
   1265   llvm::Value* param_address_typed = ir_builder_.CreateBitCast(
   1266       param_address_untyped, IrShapeType(param_shape)->getPointerTo());
   1267   emitted_value_[parameter] = param_address_typed;
   1268 
   1269   if (!ShapeUtil::IsOpaque(param_shape)) {
   1270     AttachAlignmentMetadataForLoad(param_address_untyped, param_shape);
   1271     AttachDereferenceableMetadataForLoad(param_address_untyped, param_shape);
   1272   }
   1273 
   1274   VLOG(2) << "  emitted value: " << llvm_ir::DumpToString(*param_address_typed);
   1275   return Status::OK();
   1276 }
   1277 
   1278 // Returns true if the relative order of the unreduced dimensions stays the same
   1279 // through the reduce operation.
   1280 static bool ReductionPreservesLayout(const HloInstruction& reduce) {
   1281   DCHECK_EQ(reduce.opcode(), HloOpcode::kReduce);
   1282 
   1283   // Maps dimensions that were not reduced from their dimension numbers in the
    1284   // source shape to their dimension numbers in the destination shape.
   1285   //
   1286   // So if we reduce f32[A,B,C,D] on dimensions 1 and 2, this map contains
   1287   // [0->0, 3->1].
   1288   gtl::FlatMap<int64, int64> unreduced_dim_map;
   1289 
   1290   gtl::FlatSet<int64> reduced_dims(reduce.dimensions().begin(),
   1291                                    reduce.dimensions().end());
   1292 
   1293   const Shape& operand_shape = reduce.operand(0)->shape();
   1294   const Shape& result_shape = reduce.shape();
   1295 
   1296   int64 delta = 0;
   1297   for (int64 i = 0; i < operand_shape.dimensions_size(); i++) {
   1298     if (reduced_dims.count(i)) {
   1299       delta++;
   1300     } else {
   1301       InsertOrDie(&unreduced_dim_map, i, i - delta);
   1302     }
   1303   }
   1304 
   1305   // Iterate dimensions minor to major and check that the corresponding
   1306   // dimensions in the source and target shapes are equivalent.
   1307   int64 result_dim_idx = 0;
   1308   for (int64 operand_dim_idx = 0;
   1309        operand_dim_idx < operand_shape.dimensions_size(); operand_dim_idx++) {
   1310     int64 operand_dim = operand_shape.layout().minor_to_major(operand_dim_idx);
   1311     if (!reduced_dims.count(operand_dim)) {
   1312       if (FindOrDie(unreduced_dim_map, operand_dim) !=
   1313           result_shape.layout().minor_to_major(result_dim_idx++)) {
   1314         return false;
   1315       }
   1316     }
   1317   }
   1318 
   1319   CHECK_EQ(result_dim_idx, result_shape.dimensions_size());
   1320 
   1321   return true;
   1322 }
   1323 
   1324 IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator(
   1325     HloComputation* function, string* failure_reason) const {
   1326   CHECK_EQ(function->num_parameters(), 2);
   1327 
   1328   auto root_instruction = function->root_instruction();
   1329   CHECK(ShapeUtil::IsScalar(root_instruction->shape()));
   1330 
   1331   if (root_instruction->operand_count() != 2) {
   1332     *failure_reason = "root instruction is not a binary operation";
   1333     return nullptr;
   1334   }
   1335 
   1336   const Shape& root_shape = root_instruction->shape();
   1337   if (ShapeUtil::ElementIsComplex(root_shape)) {
    1338     // TODO(b/65408531): Complex add could be done via bitcast to <float x [2N]>.
    1339     // Complex multiply would be more challenging. We could perhaps use a
    1340     // strided load to get all real parts in one vector and all imaginary parts
    1341     // in another, or use CreateShuffleVector on a bitcast to float x [2N].
   1342     *failure_reason = "complex values not supported";
   1343     return nullptr;
   1344   }
   1345   bool root_is_floating_point = ShapeUtil::ElementIsFloating(root_shape);
   1346   bool root_is_integral = ShapeUtil::ElementIsIntegral(root_shape);
   1347   bool root_is_signed = ShapeUtil::ElementIsSigned(root_shape);
   1348 
   1349   auto lhs = root_instruction->operand(0);
   1350   auto rhs = root_instruction->operand(1);
   1351 
   1352   auto param_0 = function->parameter_instruction(0);
   1353   auto param_1 = function->parameter_instruction(1);
   1354   if (!(lhs == param_0 && rhs == param_1) &&
   1355       !(rhs == param_0 && lhs == param_1)) {
   1356     *failure_reason =
   1357         "root instruction is not a binary operation on the incoming arguments";
   1358     return nullptr;
   1359   }
   1360 
   1361   CHECK(ShapeUtil::IsScalar(lhs->shape()) && ShapeUtil::IsScalar(rhs->shape()));
   1362 
   1363   // This is visually similar to ElementalIrEmitter, though conceptually we're
   1364   // doing something different here.  ElementalIrEmitter emits scalar operations
   1365   // while these emit scalar or vector operations depending on the type of the
   1366   // operands. See CreateShardedVectorType for the actual types in use here.
   1367   switch (root_instruction->opcode()) {
   1368     default:
   1369       *failure_reason = "did not recognize root instruction opcode";
   1370       return nullptr;
   1371 
   1372     case HloOpcode::kAdd:
   1373       return [root_is_integral](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs,
   1374                                 llvm::Value* rhs) {
   1375         return root_is_integral ? ir_builder->CreateAdd(lhs, rhs)
   1376                                 : ir_builder->CreateFAdd(lhs, rhs);
   1377       };
   1378 
   1379     case HloOpcode::kMultiply:
   1380       return [root_is_integral](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs,
   1381                                 llvm::Value* rhs) {
   1382         return root_is_integral ? ir_builder->CreateMul(lhs, rhs)
   1383                                 : ir_builder->CreateFMul(lhs, rhs);
   1384       };
   1385 
   1386     case HloOpcode::kAnd:
   1387       return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs,
   1388                 llvm::Value* rhs) { return ir_builder->CreateAnd(lhs, rhs); };
   1389 
   1390     case HloOpcode::kOr:
   1391       return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs,
   1392                 llvm::Value* rhs) { return ir_builder->CreateOr(lhs, rhs); };
   1393 
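             // For floating-point maximum/minimum we use the maxnum/minnum intrinsics;
             // for integers we emit a compare-and-select, using a signed or unsigned
             // comparison as appropriate.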
   1394     case HloOpcode::kMaximum:
   1395       return [root_is_floating_point, root_is_signed](
   1396                  llvm::IRBuilder<>* ir_builder, llvm::Value* lhs,
   1397                  llvm::Value* rhs) {
   1398         if (root_is_floating_point) {
   1399           return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::maxnum,
   1400                                               {lhs, rhs}, {lhs->getType()},
   1401                                               ir_builder);
   1402         }
   1403 
   1404         return ir_builder->CreateSelect(
   1405             ir_builder->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SGE
   1406                                                   : llvm::ICmpInst::ICMP_UGE,
   1407                                    lhs, rhs),
   1408             lhs, rhs);
   1409       };
   1410 
   1411     case HloOpcode::kMinimum:
   1412       return [root_is_floating_point, root_is_signed](
   1413                  llvm::IRBuilder<>* ir_builder, llvm::Value* lhs,
   1414                  llvm::Value* rhs) {
   1415         if (root_is_floating_point) {
   1416           return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::minnum,
   1417                                               {lhs, rhs}, {lhs->getType()},
   1418                                               ir_builder);
   1419         }
   1420 
   1421         return ir_builder->CreateSelect(
   1422             ir_builder->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SLE
   1423                                                   : llvm::ICmpInst::ICMP_ULE,
   1424                                    lhs, rhs),
   1425             lhs, rhs);
   1426       };
   1427   }
   1428 }
   1429 
   1430 IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType(
   1431     PrimitiveType element_type, unsigned element_count) {
   1432   int vector_register_size_in_elements =
   1433       target_machine_features_.vector_register_byte_size(
   1434           *compute_function_->function()) /
   1435       ShapeUtil::ByteSizeOfPrimitiveType(element_type);
   1436 
   1437   ShardedVectorType sharded_vector_type;
   1438   llvm::Type* element_ir_type =
   1439       llvm_ir::PrimitiveTypeToIrType(element_type, module_);
   1440 
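           // For example (assuming a 4-element vector register), element_count = 13,
           // i.e. 8 + 4 + 1, yields the shards {T, <4 x T>, <4 x T>, <4 x T>}.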
   1441   for (int i = 0, e = 1 + tensorflow::Log2Ceiling(element_count); i < e; i++) {
   1442     // For every power of two present in element_count, we generate one or more
   1443     // vector or scalar types.
   1444     const unsigned current_size_fragment = 1u << i;
   1445     if (!(element_count & current_size_fragment)) {
   1446       // Power of two not present in element_count.
   1447       continue;
   1448     }
   1449 
   1450     if (current_size_fragment == 1) {
   1451       // Single element, use a scalar type.
   1452       sharded_vector_type.push_back(element_ir_type);
   1453       continue;
   1454     }
   1455 
    1456     // Lower "current_size_fragment" elements using as few vector registers as
    1457     // possible.
   1458 
   1459     if (current_size_fragment >= vector_register_size_in_elements) {
   1460       auto vector_type = llvm::VectorType::get(
   1461           element_ir_type, vector_register_size_in_elements);
   1462       sharded_vector_type.insert(
   1463           sharded_vector_type.end(),
   1464           current_size_fragment / vector_register_size_in_elements,
   1465           vector_type);
   1466 
   1467       // Both current_size_fragment and vector_register_size_in_elements are
   1468       // powers of two.
   1469       CHECK_EQ(current_size_fragment % vector_register_size_in_elements, 0);
   1470       continue;
   1471     }
   1472 
   1473     // For now we assume that vector_register_size_in_elements and lower powers
   1474     // of two are all legal vector sizes (or at least can be lowered easily by
   1475     // LLVM).
   1476     sharded_vector_type.push_back(
   1477         llvm::VectorType::get(element_ir_type, current_size_fragment));
   1478   }
   1479   return sharded_vector_type;
   1480 }
   1481 
   1482 StatusOr<IrEmitter::ShardedVector>
   1483 IrEmitter::EmitInnerLoopForVectorizedReduction(
   1484     const ReductionGenerator& reduction_generator,
   1485     const llvm_ir::IrArray::Index& output_index,
   1486     const ShardedVectorType& accumulator_type, HloInstruction* init_value,
   1487     HloInstruction* arg, gtl::ArraySlice<int64> dimensions,
   1488     unsigned element_alignment) {
   1489   ShardedVector accumulator;
   1490   accumulator.reserve(accumulator_type.size());
   1491   for (auto accumulator_shard_type : accumulator_type) {
   1492     accumulator.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
   1493         accumulator_shard_type, "accumulator", &ir_builder_, 0));
   1494   }
   1495 
   1496   llvm::Value* init_value_ssa =
   1497       ir_builder_.CreateLoad(GetEmittedValueFor(init_value));
   1498 
   1499   for (llvm::Value* accumulator_shard : accumulator) {
   1500     llvm::Value* initial_value;
   1501     auto shard_type = accumulator_shard->getType()->getPointerElementType();
   1502     if (auto vector_type = llvm::dyn_cast<llvm::VectorType>(shard_type)) {
   1503       initial_value = ir_builder_.CreateVectorSplat(
   1504           vector_type->getNumElements(), init_value_ssa);
   1505     } else {
   1506       initial_value = init_value_ssa;
   1507     }
   1508 
   1509     ir_builder_.CreateAlignedStore(initial_value, accumulator_shard,
   1510                                    element_alignment);
   1511   }
   1512 
   1513   llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"),
   1514                                            &ir_builder_);
   1515   llvm_ir::IrArray::Index reduced_dims_index =
   1516       reduction_loop_nest.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
   1517                                                        "reduction_dim");
   1518 
   1519   SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(),
   1520                         &ir_builder_);
   1521 
   1522   llvm_ir::IrArray arg_array(GetIrArrayFor(arg));
   1523   llvm_ir::IrArray::Index input_index = reduced_dims_index;
   1524   llvm_ir::IrArray::Index::const_iterator it = output_index.begin();
   1525 
   1526   for (size_t i = 0; i < input_index.size(); ++i) {
   1527     if (input_index[i] == nullptr) {
   1528       input_index[i] = *it++;
   1529     }
   1530   }
   1531   CHECK(output_index.end() == it);
   1532 
   1533   llvm::Value* input_address = ir_builder_.CreateBitCast(
   1534       arg_array.EmitArrayElementAddress(input_index, &ir_builder_),
   1535       ir_builder_.getInt8PtrTy());
   1536 
   1537   for (int i = 0; i < accumulator.size(); i++) {
   1538     auto input_address_typed =
   1539         ir_builder_.CreateBitCast(input_address, accumulator[i]->getType());
   1540     auto current_accumulator_value =
   1541         ir_builder_.CreateAlignedLoad(accumulator[i], element_alignment);
   1542     auto addend =
   1543         ir_builder_.CreateAlignedLoad(input_address_typed, element_alignment);
   1544     arg_array.AnnotateLoadStoreInstructionWithMetadata(addend);
   1545 
   1546     auto reduced_result =
   1547         reduction_generator(&ir_builder_, current_accumulator_value, addend);
   1548     ir_builder_.CreateAlignedStore(reduced_result, accumulator[i],
   1549                                    element_alignment);
   1550 
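             // Unless this is the last shard, advance the input pointer by one shard
             // width so the next accumulator shard reads the subsequent elements.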
   1551     if (i != (accumulator.size() - 1)) {
   1552       input_address = ir_builder_.CreateConstInBoundsGEP1_32(
   1553           reduced_result->getType(), input_address_typed, 1);
   1554     }
   1555   }
   1556 
   1557   SetToFirstInsertPoint(reduction_loop_nest.GetOuterLoopExitBasicBlock(),
   1558                         &ir_builder_);
   1559 
   1560   ShardedVector result_ssa;
   1561   result_ssa.reserve(accumulator.size());
   1562   for (auto accumulator_shard : accumulator) {
   1563     result_ssa.push_back(
   1564         ir_builder_.CreateAlignedLoad(accumulator_shard, element_alignment));
   1565   }
   1566   return result_ssa;
   1567 }
   1568 
   1569 void IrEmitter::EmitShardedVectorStore(
   1570     llvm::Value* store_address, const std::vector<llvm::Value*>& value_to_store,
   1571     const int alignment, const llvm_ir::IrArray& containing_array) {
   1572   for (int i = 0; i < value_to_store.size(); i++) {
   1573     auto store_address_typed = ir_builder_.CreateBitCast(
   1574         store_address,
   1575         llvm::PointerType::getUnqual(value_to_store[i]->getType()));
   1576 
   1577     auto store_instruction = ir_builder_.CreateAlignedStore(
   1578         value_to_store[i], store_address_typed, alignment);
   1579     containing_array.AnnotateLoadStoreInstructionWithMetadata(
   1580         store_instruction);
   1581 
   1582     if (i != (value_to_store.size() - 1)) {
   1583       store_address = ir_builder_.CreateConstInBoundsGEP1_32(
   1584           value_to_store[i]->getType(), store_address_typed, 1);
   1585     }
   1586   }
   1587 }
   1588 
   1589 StatusOr<bool> IrEmitter::EmitVectorizedReduce(
   1590     HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
   1591     gtl::ArraySlice<int64> dimensions, HloComputation* function,
   1592     string* failure_reason) {
   1593   if (!ReductionPreservesLayout(*reduce)) {
   1594     return false;
   1595   }
   1596 
   1597   ReductionGenerator reduction_generator =
   1598       MatchReductionGenerator(function, failure_reason);
   1599   if (!reduction_generator) {
   1600     return false;
   1601   }
   1602 
   1603   int vectorization_factor_in_bytes =
   1604       target_machine_features_.vectorization_factor_in_bytes();
   1605 
   1606   // We try to process vectorization_factor elements at the same time.
   1607   const int vectorization_factor =
   1608       vectorization_factor_in_bytes /
   1609       ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type());
   1610 
   1611   bool is_reduction_over_minor_dimension =
   1612       std::find(dimensions.begin(), dimensions.end(),
   1613                 LayoutUtil::Minor(arg->shape().layout(), 0)) !=
   1614       dimensions.end();
   1615 
   1616   unsigned element_alignment = tensorflow::MathUtil::GCD<unsigned>(
   1617       ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()),
   1618       MinimumAlignmentForPrimitiveType(reduce->shape().element_type()));
   1619 
   1620   if (is_reduction_over_minor_dimension) {
   1621     // TODO(sanjoy): Implement vectorized reduction over the minor dimension.
   1622     *failure_reason = "reduction over minor dimension not implemented";
   1623     return false;
   1624   }
   1625 
   1626   CHECK(!ShapeUtil::IsTuple(reduce->shape()));
   1627   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(reduce));
   1628 
   1629   // We know we're not reducing over the most minor dimension, which means we
   1630   // can lower the reduction loop as:
   1631   //
   1632   //  1. We're reducing over dimensions R0, R1.
   1633   //  2. D0 is the most minor dimension.
   1634   //  3. VS is the vectorization stride (we want to reduce this many elements at
   1635   //     once)
   1636   //
   1637   //  for (d1 in D1) {
   1638   //    for (d0 in D0 with stride VS) {
   1639   //      vector_acc = init
   1640   //      for (r1 in R1) {
   1641   //        for (r0 in R0) {
    1642   //          vector_acc = elementwise_reduce(vector_acc, input[d1, d0, r1, r0])
   1643   //        }
   1644   //      }
   1645   //      output[d1, d0] = vector_acc
   1646   //    }
   1647   //  }
   1648 
   1649   llvm_ir::ForLoopNest loop_nest(IrName(reduce), &ir_builder_);
   1650   llvm_ir::IrArray::Index array_index(reduce->shape().dimensions_size());
   1651   for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0;
   1652        --i) {
   1653     int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i);
   1654     int64 start_index = 0;
   1655     int64 end_index = reduce->shape().dimensions(dimension);
   1656     std::unique_ptr<llvm_ir::ForLoop> loop =
   1657         loop_nest.AddLoop(start_index, end_index,
   1658                           tensorflow::strings::Printf("dim.%lld", dimension));
   1659     array_index[dimension] = loop->GetIndVarValue();
   1660   }
   1661 
   1662   int64 innermost_dimension = LayoutUtil::Minor(reduce->shape().layout(), 0);
   1663   int64 innermost_dimension_size =
   1664       reduce->shape().dimensions(innermost_dimension);
   1665 
   1666   if (llvm::BasicBlock* innermost_body_bb =
   1667           loop_nest.GetInnerLoopBodyBasicBlock()) {
   1668     SetToFirstInsertPoint(innermost_body_bb, &ir_builder_);
   1669   }
   1670 
   1671   auto outermost_loop_exit_block = loop_nest.GetOuterLoopExitBasicBlock();
   1672 
   1673   if (innermost_dimension_size >= vectorization_factor) {
   1674     int64 start_index = 0;
   1675     int64 end_index = (innermost_dimension_size / vectorization_factor) *
   1676                       vectorization_factor;
   1677     std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop(
   1678         start_index, end_index, vectorization_factor,
   1679         tensorflow::strings::Printf("dim.%lld", innermost_dimension));
   1680     array_index[innermost_dimension] = loop->GetIndVarValue();
   1681 
   1682     SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &ir_builder_);
   1683 
   1684     ShardedVectorType vector_type = CreateShardedVectorType(
   1685         reduce->shape().element_type(), vectorization_factor);
   1686     TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accumulator,
   1687                         EmitInnerLoopForVectorizedReduction(
   1688                             reduction_generator, array_index, vector_type,
   1689                             init_value, arg, dimensions, element_alignment));
   1690 
   1691     llvm_ir::IrArray target_array = GetIrArrayFor(reduce);
   1692     llvm::Value* output_address =
   1693         target_array.EmitArrayElementAddress(array_index, &ir_builder_);
   1694     EmitShardedVectorStore(output_address, accumulator, element_alignment,
   1695                            target_array);
   1696 
   1697     if (auto exit_terminator = loop->GetExitBasicBlock()->getTerminator()) {
   1698       CHECK_GT(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1);
   1699       ir_builder_.SetInsertPoint(exit_terminator);
   1700     } else {
   1701       CHECK_EQ(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1);
   1702       ir_builder_.SetInsertPoint(loop->GetExitBasicBlock());
   1703     }
   1704   }
   1705 
    1706   // Since the innermost dimension is advanced by more than one element per
    1707   // iteration, we may need to peel off an "epilogue" iteration to handle the
    1708   // leftover elements when the dimension size is not a multiple of the factor:
   1709   if (innermost_dimension_size % vectorization_factor) {
   1710     // TODO(b/63775531): Consider using a scalar loop here to save on code size.
   1711     array_index[innermost_dimension] =
   1712         ir_builder_.getInt64(innermost_dimension_size -
   1713                              (innermost_dimension_size % vectorization_factor));
   1714 
   1715     ShardedVectorType vector_type = CreateShardedVectorType(
   1716         reduce->shape().element_type(),
   1717         innermost_dimension_size % vectorization_factor);
   1718     TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accumulator,
   1719                         EmitInnerLoopForVectorizedReduction(
   1720                             reduction_generator, array_index, vector_type,
   1721                             init_value, arg, dimensions, element_alignment));
   1722 
   1723     llvm_ir::IrArray target_array = GetIrArrayFor(reduce);
   1724     llvm::Value* output_address =
   1725         target_array.EmitArrayElementAddress(array_index, &ir_builder_);
   1726     EmitShardedVectorStore(output_address, accumulator, element_alignment,
   1727                            target_array);
   1728   }
   1729 
   1730   if (outermost_loop_exit_block) {
   1731     ir_builder_.SetInsertPoint(outermost_loop_exit_block);
   1732   }
   1733 
   1734   return true;
   1735 }
   1736 
   1737 Status IrEmitter::HandleReduce(HloInstruction* reduce) {
   1738   auto arg = reduce->mutable_operand(0);
   1739   auto init_value = reduce->mutable_operand(1);
   1740   gtl::ArraySlice<int64> dimensions(reduce->dimensions());
   1741   HloComputation* function = reduce->to_apply();
   1742   if (!options::VectorizedReduceDisabled(hlo_module_config_)) {
   1743     string vectorization_failure_reason;
   1744     TF_ASSIGN_OR_RETURN(
   1745         bool vectorization_successful,
   1746         EmitVectorizedReduce(reduce, arg, init_value, dimensions, function,
   1747                              &vectorization_failure_reason));
   1748     if (vectorization_successful) {
   1749       VLOG(1) << "Successfully vectorized reduction " << reduce->ToString()
   1750               << "\n";
   1751       return Status::OK();
   1752     } else {
   1753       VLOG(1) << "Could not vectorize reduction " << reduce->ToString() << ": "
   1754               << vectorization_failure_reason;
   1755     }
   1756   }
   1757 
   1758   // The called computation should have been emitted previously.
   1759   llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
   1760   return EmitTargetElementLoop(
   1761       reduce, [this, reduce, arg, init_value, dimensions,
   1762                reducer_function](const llvm_ir::IrArray::Index& index) {
   1763         // Initialize an accumulator with init_value.
   1764         PrimitiveType accumulator_type = reduce->shape().element_type();
   1765         llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
   1766             llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_),
   1767             "accumulator", &ir_builder_,
   1768             MinimumAlignmentForPrimitiveType(accumulator_type));
   1769         llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
   1770         llvm::Value* load_init_value = ir_builder_.CreateLoad(init_value_addr);
   1771         ir_builder_.CreateStore(load_init_value, accumulator_addr);
   1772 
   1773         // The enclosing loops go over all the target elements. Now we have to
   1774         // compute the actual target element. For this, we build a new loop nest
   1775         // to iterate over all the reduction dimensions in the argument.
   1776         // AddLoopsForShapeOnDimensions will return an Index where induction
   1777         // Value*s are placed for each dimension in dimensions, and all the rest
   1778         // are nullptrs.
   1779         llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &ir_builder_);
   1780         const llvm_ir::IrArray::Index reduced_dims_index =
   1781             loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
   1782                                                "reduction_dim");
   1783 
   1784         SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
   1785 
   1786         // Build a full index for the input argument, using reduced_dims_index
   1787         // as the base. In reduced_dims_index only the reduction dimensions are
   1788         // filled in. We fill in the rest of the dimensions with induction
   1789         // Value*s taken from 'index' which iterates over the target array.
   1790         // See the high-level description in the XLA documentation for details.
   1791         llvm_ir::IrArray arg_array(GetIrArrayFor(arg));
   1792         llvm_ir::IrArray::Index input_index = reduced_dims_index;
   1793         llvm_ir::IrArray::Index::const_iterator it = index.begin();
   1794 
   1795         for (size_t i = 0; i < input_index.size(); ++i) {
   1796           if (input_index[i] == nullptr) {
   1797             input_index[i] = *it++;
   1798           }
   1799         }
   1800         CHECK(index.end() == it);
   1801 
   1802         // Apply the reduction function to the loaded value.
   1803         llvm::Value* input_address =
   1804             arg_array.EmitArrayElementAddress(input_index, &ir_builder_);
   1805         llvm::Value* result = EmitElementFunctionCall(
   1806             reducer_function, reduce->shape(),
   1807             {accumulator_addr, input_address}, "reduce_function");
   1808         ir_builder_.CreateStore(result, accumulator_addr);
   1809 
   1810         SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
   1811         return ir_builder_.CreateLoad(accumulator_addr);
   1812       });
   1813 }
   1814 
   1815 Status IrEmitter::HandleSend(HloInstruction* send) {
   1816   // TODO(b/33942983): Support Send/Recv on CPU.
   1817   return Unimplemented("Send is not implemented on CPU.");
   1818 }
   1819 
   1820 Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
   1821   // TODO(b/33942983): Support Send/Recv on CPU.
   1822   return Unimplemented("Send-done is not implemented on CPU.");
   1823 }
   1824 
   1825 Status IrEmitter::HandleSlice(HloInstruction* slice) {
   1826   VLOG(2) << "HandleSlice: " << slice->ToString();
   1827   auto operand = slice->operand(0);
   1828   // The code below emits a sequential loop nest. For the parallel backend, use
   1829   // ParallelLoopEmitter which respects dynamic loop bounds.
   1830   if (ShouldEmitParallelLoopFor(*slice)) {
   1831     return DefaultAction(slice);
   1832   }
   1833 
   1834   // The code below assumes the layouts are equal.
   1835   if (!LayoutUtil::Equal(operand->shape().layout(), slice->shape().layout())) {
   1836     return DefaultAction(slice);
   1837   }
   1838 
   1839   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice));
   1840 
   1841   if (ShapeUtil::HasZeroElements(slice->shape())) {
   1842     return Status::OK();
   1843   }
   1844 
   1845   const Layout& layout = operand->shape().layout();
   1846   const int64 num_dims = operand->shape().dimensions_size();
   1847 
   1848   // The slice lowering finds maximal contiguous blocks of memory that can be
   1849   // copied from the source to the target. This is done by looking at the
    1850   // source/target layout in minor to major order and doing the following:
   1851   //
   1852   // * Find an initial segment of dimensions along which the slice uses the
   1853   //   whole dimension. These are the "inner" dimensions and can be folded into
   1854   //   the memcpy.
   1855   //
   1856   // * Of the remaining dimensions decide which ones require loops.
   1857   //
   1858   // * Implement the memcpy within the innermost loop.
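           //
           // As a hypothetical example: slicing f32[8,9,10] (default layout) with
           // starts={2,0,0}, limits={6,9,10} and strides={1,1,1} keeps dimensions 1 and
           // 2 whole, so they fold into the memcpy; a single call then copies
           // 4 * 9 * 10 elements and no outer loops are needed.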
   1859 
   1860   gtl::FlatSet<int64> inner_dims;
   1861   for (int64 dim : LayoutUtil::MinorToMajor(layout)) {
   1862     if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) {
   1863       break;
   1864     }
   1865     inner_dims.insert(dim);
   1866   }
   1867 
   1868   const bool is_trivial_copy = (inner_dims.size() == num_dims);
   1869   if (is_trivial_copy) {
   1870     if (ShapeUtil::IsEffectiveScalar(slice->shape())) {
   1871       return DefaultAction(slice);
   1872     } else {
   1873       return EmitMemcpy(*slice, *operand);
   1874     }
   1875   }
   1876 
   1877   // The memcpy will copy elements that are logically this shape (allowed to be
   1878   // scalar).
   1879   const Shape logical_element_shape = ShapeUtil::FilterDimensions(
   1880       [&inner_dims](int64 dim) -> bool { return inner_dims.count(dim); },
   1881       operand->shape());
   1882 
   1883   const int64 primitive_elements_per_logical_element =
   1884       ShapeUtil::ElementsIn(logical_element_shape);
   1885 
   1886   // memcpy_dim is the innermost (in terms of layout) dimension for which the
   1887   // slice does *not* just copy all the elements along the dimension.
   1888   const int64 memcpy_dim = LayoutUtil::Minor(layout, inner_dims.size());
   1889 
   1890   const bool memcpy_is_contiguous = slice->slice_strides(memcpy_dim) == 1;
   1891   // The number of logical elements that can be copied in a single call
   1892   // to memcpy. We can only copy 1 element at a time if there is a non-trivial
   1893   // stride.
   1894   const int64 memcpy_logical_elements =
   1895       memcpy_is_contiguous
   1896           ? slice->slice_limits(memcpy_dim) - slice->slice_starts(memcpy_dim)
   1897           : 1;
   1898 
   1899   // Determine the dimensions that get lowered as loops.
   1900   std::vector<int64> outer_dims;
   1901   for (int64 i = 0; i < num_dims - inner_dims.size() - 1; ++i) {
   1902     outer_dims.push_back(LayoutUtil::Major(layout, i));
   1903   }
   1904 
   1905   // Is the slice along the memcpy dimension contiguous? If not, then memcpy_dim
    1906   // needs to be wrapped in a loop as well.
   1907   if (!memcpy_is_contiguous) {
   1908     outer_dims.push_back(memcpy_dim);
   1909   }
   1910 
   1911   llvm_ir::IrArray target_array = GetIrArrayFor(slice);
   1912 
   1913   const int64 num_outer_loops = outer_dims.size();
   1914   llvm_ir::ForLoopNest loops(IrName(slice), &ir_builder_);
   1915   llvm_ir::IrArray::Index target_index =
   1916       loops.AddLoopsForShapeOnDimensions(slice->shape(), outer_dims, "slice");
   1917 
   1918   // Only the indices for the outer dimensions have been initialized in
   1919   // target_index. The rest of the indices should get initialized to 0, since
   1920   // for the rest of the dimensions the copy writes to the full dimension.
   1921   std::replace(target_index.begin(), target_index.end(),
   1922                static_cast<llvm::Value*>(nullptr),
   1923                static_cast<llvm::Value*>(ir_builder_.getInt64(0)));
   1924 
   1925   if (num_outer_loops > 0) {
   1926     SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
   1927   }
   1928 
   1929   llvm_ir::IrArray source_array = GetIrArrayFor(operand);
   1930   const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice(
   1931       /*shape=*/slice->shape(), /*starts=*/slice->slice_starts(),
   1932       /*strides=*/slice->slice_strides(), /*builder=*/&ir_builder_);
   1933 
   1934   llvm::Value* memcpy_dest = target_array.EmitArrayElementAddress(
   1935       target_index, &ir_builder_, "slice.dest");
   1936   llvm::Value* memcpy_source = source_array.EmitArrayElementAddress(
   1937       source_index, &ir_builder_, "slice.source");
   1938 
   1939   const int64 memcpy_elements =
   1940       primitive_elements_per_logical_element * memcpy_logical_elements;
   1941 
   1942   EmitTransferElements(memcpy_dest, memcpy_source, memcpy_elements,
   1943                        slice->shape().element_type(), target_array,
   1944                        source_array);
   1945 
   1946   if (VLOG_IS_ON(2)) {
   1947     const int64 memcpy_bytes =
   1948         ShapeUtil::ByteSizeOf(logical_element_shape) * memcpy_elements;
   1949     VLOG(2) << "  emitted copy of " << memcpy_bytes << " bytes inside "
   1950             << num_outer_loops << " loops";
   1951   }
   1952 
   1953   if (num_outer_loops > 0) {
   1954     SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
   1955   }
   1956 
   1957   return Status::OK();
   1958 }
   1959 
   1960 Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice) {
   1961   if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
   1962     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_slice));
   1963     return EmitMemcpy(*dynamic_slice->operand(0), *dynamic_slice);
   1964   }
   1965   return DefaultAction(dynamic_slice);
   1966 }
   1967 
   1968 Status IrEmitter::HandleDynamicUpdateSlice(
   1969     HloInstruction* dynamic_update_slice) {
   1970   auto update = dynamic_update_slice->operand(1);
   1971   if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) {
   1972     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice));
   1973     return EmitMemcpy(*update, *dynamic_update_slice);
   1974   } else if (llvm_ir::CanUpdateDynamicSliceInPlace(dynamic_update_slice,
   1975                                                    assignment_)) {
   1976     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice));
   1977     auto operands = GetIrArraysForOperandsOf(dynamic_update_slice);
   1978     return llvm_ir::EmitDynamicUpdateSliceInPlace(
   1979         operands, GetIrArrayFor(dynamic_update_slice),
   1980         IrName(dynamic_update_slice, "in_place"), &ir_builder_);
   1981   }
   1982   return DefaultAction(dynamic_update_slice);
   1983 }
   1984 
   1985 Status IrEmitter::HandleRecv(HloInstruction* recv) {
   1986   // TODO(b/33942983): Support Send/Recv on CPU.
   1987   return Unimplemented("Recv is not implemented on CPU.");
   1988 }
   1989 
   1990 Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) {
   1991   // TODO(b/33942983): Support Send/Recv on CPU.
   1992   return Unimplemented("Recv-done is not implemented on CPU.");
   1993 }
   1994 
   1995 Status IrEmitter::HandlePad(HloInstruction* pad) {
    1996   // The CPU backend does not properly handle negative padding, but this is OK
    1997   // because negative padding should have been removed by the algebraic simplifier.
   1998   for (auto& padding_dimension : pad->padding_config().dimensions()) {
   1999     if (padding_dimension.edge_padding_low() < 0 ||
   2000         padding_dimension.edge_padding_high() < 0) {
   2001       return InternalErrorStrCat(
   2002           "Encountered negative padding in IrEmitter on CPU. "
   2003           "This should have been eliminated at the HLO level. ",
   2004           pad->ToString());
   2005     }
   2006   }
   2007 
   2008   // First, fill in the padding value to all output elements.
   2009   TF_RETURN_IF_ERROR(EmitTargetElementLoop(
   2010       pad, "initialize",
   2011       [this, pad](const llvm_ir::IrArray::Index& target_index) {
   2012         const HloInstruction* padding_value = pad->operand(1);
   2013         llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value);
   2014         return ir_builder_.CreateLoad(padding_value_addr);
   2015       }));
   2016 
   2017   // Create a loop to iterate over the operand elements and update the output
   2018   // locations where the operand elements should be stored.
   2019   llvm_ir::ForLoopNest loops(IrName(pad, "assign"), &ir_builder_);
   2020   const HloInstruction* operand = pad->operand(0);
   2021   const llvm_ir::IrArray::Index operand_index =
   2022       loops.AddLoopsForShape(operand->shape(), "operand");
   2023 
   2024   SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
   2025 
   2026   // Load an element from the operand.
   2027   llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
   2028   llvm::Value* operand_data =
   2029       operand_array.EmitReadArrayElement(operand_index, &ir_builder_);
   2030 
   2031   // Compute the output index the operand element should be assigned to.
   2032   // output_index := edge_padding_low + operand_index * (interior_padding + 1)
   2033   const PaddingConfig& padding_config = pad->padding_config();
   2034   llvm_ir::IrArray::Index output_index;
   2035   for (size_t i = 0; i < operand_index.size(); ++i) {
   2036     llvm::Value* offset = ir_builder_.CreateMul(
   2037         operand_index[i],
   2038         ir_builder_.getInt64(padding_config.dimensions(i).interior_padding() +
   2039                              1));
   2040     llvm::Value* index = ir_builder_.CreateAdd(
   2041         offset,
   2042         ir_builder_.getInt64(padding_config.dimensions(i).edge_padding_low()));
   2043     output_index.push_back(index);
   2044   }
   2045 
   2046   // Store the operand element to the computed output location.
   2047   llvm_ir::IrArray output_array(GetIrArrayFor(pad));
   2048   output_array.EmitWriteArrayElement(output_index, operand_data, &ir_builder_);
   2049 
   2050   SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
   2051   return Status::OK();
   2052 }
   2053 
    2054 // If `hlo` is a rank-2 transpose, returns its operand, else `hlo` itself.
   2055 static const HloInstruction* StripTranspose(const HloInstruction& hlo) {
   2056   if (hlo.IsRank2Transpose()) {
   2057     return hlo.operand(0);
   2058   }
   2059   return &hlo;
   2060 }
   2061 
   2062 Status IrEmitter::HandleFusion(HloInstruction* fusion) {
   2063   auto* root = fusion->fused_expression_root();
   2064   if (fusion->fusion_kind() == HloInstruction::FusionKind::kTransposeDot) {
   2065     DCHECK(root->opcode() == HloOpcode::kDot);
   2066     const HloInstruction* lhs_parameter = StripTranspose(*root->operand(0));
   2067     const HloInstruction* rhs_parameter = StripTranspose(*root->operand(1));
   2068     DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter &&
   2069            rhs_parameter->opcode() == HloOpcode::kParameter);
   2070     const HloInstruction* lhs =
   2071         fusion->operand(lhs_parameter->parameter_number());
   2072     const HloInstruction* rhs =
   2073         fusion->operand(rhs_parameter->parameter_number());
   2074 
   2075     TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
   2076         /*instruction=*/*root, /*operands=*/{lhs, rhs},
   2077         /*supported_types=*/{F32}));
   2078 
   2079     llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs));
   2080     llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs));
   2081 
   2082     Shape target_shape = fusion->shape();
   2083     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
   2084     llvm_ir::IrArray target_array = GetIrArrayFor(fusion);
   2085     VLOG(2) << "HandleFusion kTransposeDot: ";
   2086     VLOG(2) << "  lhs operand: "
   2087             << llvm_ir::DumpToString(*lhs_array.GetBasePointer());
   2088     VLOG(2) << "  rhs operand: "
   2089             << llvm_ir::DumpToString(*rhs_array.GetBasePointer());
   2090     VLOG(2) << "  target: "
   2091             << llvm_ir::DumpToString(*target_array.GetBasePointer());
   2092 
    2093     // The dot operation is complicated, so we delegate to a helper class.
   2094     TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation(
   2095         *root, root->operand(0)->IsRank2Transpose(),
   2096         root->operand(1)->IsRank2Transpose(), target_array, lhs_array,
   2097         rhs_array, /*addend_array=*/nullptr, GetExecutableRunOptionsArgument(),
   2098         &ir_builder_, hlo_module_config_, target_machine_features_));
   2099     return Status::OK();
   2100   } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion,
   2101                                                             assignment_)) {
   2102     VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace";
   2103     CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
   2104     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
   2105 
   2106     // Delegate to common implementation of fused in-place dynamic-update-slice.
   2107     auto operands = GetIrArraysForOperandsOf(fusion);
   2108     return llvm_ir::EmitFusedDynamicUpdateSliceInPlace(
   2109         fusion, operands, GetIrArrayFor(fusion), &elemental_emitter,
   2110         &ir_builder_);
   2111   } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) {
   2112     VLOG(3) << "HandleFusion kLoop";
   2113     CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
   2114     auto operands = GetIrArraysForOperandsOf(fusion);
   2115     FusedIrEmitter fused_emitter(operands, &elemental_emitter);
   2116     TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
   2117 
   2118     return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator());
   2119   } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kOutput) {
   2120     VLOG(3) << "HandleFusion kOutput";
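             // The fused root is expected to be an add of a dot and an addend; figure
             // out which root operand is the dot so the other can be used as the addend.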
   2121     int64 dot_op_index = root->operand(0)->opcode() == HloOpcode::kDot ? 0 : 1;
   2122     const HloInstruction* dot = root->operand(dot_op_index);
   2123     CHECK_EQ(dot->opcode(), HloOpcode::kDot)
   2124         << dot->ToString() << "  "
   2125         << fusion->fused_instructions_computation()->ToString();
   2126 
   2127     int64 dot_lhs_param_number = dot->operand(0)->parameter_number();
   2128     int64 dot_rhs_param_number = dot->operand(1)->parameter_number();
   2129     int64 addend_param_number =
   2130         root->operand(1 - dot_op_index)->parameter_number();
   2131 
   2132     Shape target_shape = fusion->shape();
   2133     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
   2134     llvm_ir::IrArray target_array = GetIrArrayFor(fusion);
   2135 
   2136     llvm_ir::IrArray lhs_array(
   2137         GetIrArrayFor(fusion->operand(dot_lhs_param_number)));
   2138     llvm_ir::IrArray rhs_array(
   2139         GetIrArrayFor(fusion->operand(dot_rhs_param_number)));
   2140     llvm_ir::IrArray addend_array(
   2141         GetIrArrayFor(fusion->operand(addend_param_number)));
   2142 
   2143     TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation(
   2144         *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array,
   2145         lhs_array, rhs_array, &addend_array, GetExecutableRunOptionsArgument(),
   2146         &ir_builder_, hlo_module_config_, target_machine_features_));
   2147     return Status::OK();
   2148   } else {
   2149     return Unimplemented("Fusion kind not implemented on CPU");
   2150   }
   2151 }
   2152 
   2153 Status IrEmitter::HandleCall(HloInstruction* call) {
   2154   HloComputation* computation = call->to_apply();
   2155   llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation);
   2156 
   2157   std::vector<llvm::Value*> parameter_addresses;
   2158   for (const HloInstruction* operand : call->operands()) {
   2159     parameter_addresses.push_back(GetEmittedValueFor(operand));
   2160   }
   2161 
   2162   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call));
   2163 
   2164   if (!computation->root_instruction()->outer_dimension_partitions().empty() &&
   2165       !parallel_cpu_backend_) {
   2166     // ParallelTaskAssignment assigned partitions, emit call to
   2167     // ParallelForkJoin.
   2168     std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments(
   2169         parameter_addresses, &ir_builder_, computation->name(),
   2170         /*return_value_buffer=*/emitted_value_[call],
   2171         /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
   2172         /*temp_buffers_arg=*/GetTempBuffersArgument(),
   2173         /*profile_counters_arg=*/GetProfileCountersArgument());
   2174 
   2175     HloInstruction* root = computation->root_instruction();
   2176     TF_RETURN_IF_ERROR(EmitCallToParallelForkJoin(
   2177         call_args, root->shape(), root->outer_dimension_partitions(),
   2178         &ir_builder_, call_ir_function, computation->name()));
   2179   } else {
   2180     EmitArrayFunctionCallInto(call_ir_function, parameter_addresses,
   2181                               emitted_value_[call], computation->name());
   2182   }
   2183 
   2184   return Status::OK();
   2185 }
   2186 
   2187 Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
   2188   gtl::ArraySlice<HloInstruction*> operands(custom_call->operands());
   2189   tensorflow::StringPiece custom_call_target(custom_call->custom_call_target());
   2190   llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy();
   2191   llvm::AllocaInst* operands_alloca =
   2192       llvm_ir::EmitAllocaAtFunctionEntryWithCount(
   2193           i8_ptr_type, ir_builder_.getInt32(operands.size()),
   2194           "cc_operands_alloca", &ir_builder_);
   2195   for (size_t i = 0; i < operands.size(); ++i) {
   2196     const HloInstruction* operand = operands[i];
   2197     llvm::Value* operand_as_i8ptr =
   2198         ir_builder_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type);
   2199     llvm::Value* slot_in_operands_alloca = ir_builder_.CreateInBoundsGEP(
   2200         operands_alloca, {ir_builder_.getInt64(i)});
   2201     ir_builder_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca);
   2202   }
   2203   auto* custom_call_ir_function =
   2204       llvm::cast<llvm::Function>(module_->getOrInsertFunction(
   2205           AsStringRef(custom_call_target),
   2206           llvm::FunctionType::get(
   2207               /*Result=*/ir_builder_.getVoidTy(),
   2208               /*Params=*/{i8_ptr_type, operands_alloca->getType()},
   2209               /*isVarArg=*/false)));
   2210 
   2211   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call));
   2212   auto* output_address_arg = ir_builder_.CreatePointerCast(
   2213       GetEmittedValueFor(custom_call), i8_ptr_type);
   2214 
   2215   ir_builder_.CreateCall(custom_call_ir_function,
   2216                          {output_address_arg, operands_alloca});
   2217 
   2218   return Status::OK();
   2219 }
   2220 
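         // Lowers a kWhile HLO to a header/body/exit basic-block structure.  The
         // condition and body computations must already have been emitted, and all
         // while-related buffers must share the init operand's allocation.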
   2221 Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
   2222   // Precondition: Condition computation must return a scalar bool.
   2223   HloComputation* condition = xla_while->while_condition();
   2224   TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) &&
   2225                condition->root_instruction()->shape().element_type() == PRED)
   2226       << "While condition computation must return bool; got: "
   2227       << ShapeUtil::HumanString(condition->root_instruction()->shape());
   2228   // Check that all while-related buffers share an allocation slice.
   2229   TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
   2230       xla_while->shape(),
   2231       [this, &xla_while](const Shape& /*subshape*/,
   2232                          const ShapeIndex& index) -> Status {
   2233         auto check = [this](const HloInstruction* a, const HloInstruction* b,
   2234                             const ShapeIndex& index) {
   2235           const BufferAllocation::Slice slice_a =
   2236               assignment_.GetUniqueSlice(a, index).ConsumeValueOrDie();
   2237           const BufferAllocation::Slice slice_b =
   2238               assignment_.GetUniqueSlice(b, index).ConsumeValueOrDie();
   2239           if (slice_a != slice_b) {
   2240             return InternalError(
   2241                 "instruction %s %s does not share slice with "
   2242                 "instruction %s %s",
   2243                 a->ToString().c_str(), slice_a.ToString().c_str(),
   2244                 b->ToString().c_str(), slice_b.ToString().c_str());
   2245           }
   2246           return Status::OK();
   2247         };
   2248         TF_RETURN_IF_ERROR(check(xla_while, xla_while->operand(0), index));
   2249         TF_RETURN_IF_ERROR(check(
   2250             xla_while, xla_while->while_condition()->parameter_instruction(0),
   2251             index));
   2252         TF_RETURN_IF_ERROR(
   2253             check(xla_while, xla_while->while_body()->parameter_instruction(0),
   2254                   index));
   2255         TF_RETURN_IF_ERROR(check(
   2256             xla_while, xla_while->while_body()->root_instruction(), index));
   2257         return Status::OK();
   2258       }));
   2259 
   2260   // Set emitted value to that of 'init' with which it shares an allocation.
   2261   const HloInstruction* init = xla_while->operand(0);
   2262   emitted_value_[xla_while] = GetEmittedValueFor(init);
   2263 
   2264   // The called computation should have been emitted previously.
   2265   llvm::Function* condition_ir_function =
   2266       FindOrDie(emitted_functions_, condition);
   2267   llvm::Function* body_ir_function =
   2268       FindOrDie(emitted_functions_, xla_while->while_body());
   2269 
   2270   // Generating:
   2271   //   while (Condition(while_result)) {
   2272   //     // CopyInsertion pass inserts copies which enable 'while_result' to
   2273   //     // be passed back in as 'Body' parameter.
    2274   //     while_result = Body(while_result);
   2275   //   }
   2276 
   2277   // Terminates the current block with a branch to a while header.
   2278   llvm::BasicBlock* header_bb = llvm::BasicBlock::Create(
   2279       module_->getContext(), AsStringRef(IrName(xla_while, "header")),
   2280       compute_function_->function());
   2281   ir_builder_.CreateBr(header_bb);
   2282   ir_builder_.SetInsertPoint(header_bb);
   2283 
   2284   // Calls the condition function to determine whether to proceed with the
   2285   // body.  It must return a bool, so use the scalar call form.
   2286   llvm::Value* while_result = GetEmittedValueFor(xla_while);
   2287   llvm::Value* while_condition = EmitElementFunctionCall(
   2288       condition_ir_function, condition->root_instruction()->shape(),
   2289       {while_result}, IrName(xla_while, "cond"));
   2290   llvm::Value* while_predicate = ir_builder_.CreateICmpNE(
   2291       while_condition,
   2292       llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0));
   2293 
   2294   // Branches to the body or to the while exit depending on the condition.
   2295   llvm::BasicBlock* body_bb = llvm::BasicBlock::Create(
   2296       module_->getContext(), AsStringRef(IrName(xla_while, "body")),
   2297       compute_function_->function());
   2298   llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create(
   2299       module_->getContext(), AsStringRef(IrName(xla_while, "exit")));
   2300   ir_builder_.CreateCondBr(while_predicate, body_bb, exit_bb);
   2301 
   2302   // Calls the body function from the body block.
   2303   ir_builder_.SetInsertPoint(body_bb);
   2304 
   2305   // Calls the body function.
   2306   EmitArrayFunctionCallInto(body_ir_function, {while_result}, while_result,
   2307                             IrName(xla_while, "body"));
   2308   // Finishes with a branch back to the header.
   2309   ir_builder_.CreateBr(header_bb);
   2310 
   2311   // Adds the exit block to the function and sets the insert point there.
   2312   compute_function_->function()->getBasicBlockList().push_back(exit_bb);
   2313   ir_builder_.SetInsertPoint(exit_bb);
   2314 
   2315   return Status::OK();
   2316 }
   2317 
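         // Attempts to lower a concatenate as a sequence of memcpys, one per
         // operand, along the concatenation dimension.  Returns false and fills
         // *failure_reason when the fast path does not apply (parallel backend,
         // mismatching or padded layouts).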
   2318 StatusOr<bool> IrEmitter::EmitFastConcatenate(
   2319     HloInstruction* concatenate, gtl::ArraySlice<HloInstruction*> operands,
   2320     string* failure_reason) {
   2321   if (ShouldEmitParallelLoopFor(*concatenate)) {
   2322     *failure_reason =
   2323         "cannot generate memcpy-based concat for the parallel CPU backend";
   2324     return false;
   2325   }
   2326 
   2327   const Shape& output_shape = concatenate->shape();
   2328   for (auto* op : operands) {
   2329     if (!LayoutUtil::Equal(op->shape().layout(), output_shape.layout())) {
   2330       *failure_reason = "operand has mismatching layouts";
   2331       return false;
   2332     }
   2333     if (LayoutUtil::IsPadded(op->shape())) {
   2334       *failure_reason = "operand has padded layout";
   2335       return false;
   2336     }
   2337   }
   2338 
   2339   CHECK(!LayoutUtil::IsPadded(concatenate->shape()));
   2340 
   2341   // We split the dimensions into three categories: the dimension over which we
   2342   // are concatenating (concat_dim), the dimensions that are minor to it
   2343   // (inner_dims) and the dimensions that are major to it (outer_dims).
   2344 
   2345   int64 concat_dim = concatenate->dimensions(0);
   2346   const Layout& output_layout = output_shape.layout();
   2347   auto output_min2maj = LayoutUtil::MinorToMajor(output_layout);
   2348   auto concat_dim_layout_itr =
   2349       std::find(output_min2maj.begin(), output_min2maj.end(), concat_dim);
   2350 
   2351   std::vector<int64> inner_dims(output_min2maj.begin(), concat_dim_layout_itr);
   2352   std::vector<int64> outer_dims(std::next(concat_dim_layout_itr),
   2353                                 output_min2maj.end());
   2354 
   2355   llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy();
   2356   llvm::Type* i8_type = ir_builder_.getInt8Ty();
   2357 
   2358   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate));
   2359   llvm_ir::IrArray target_array = GetIrArrayFor(concatenate);
   2360 
   2361   llvm_ir::ForLoopNest loops(IrName(concatenate), &ir_builder_);
   2362   llvm_ir::IrArray::Index outer_dims_index =
   2363       loops.AddLoopsForShapeOnDimensions(output_shape, outer_dims, "concat");
   2364   std::replace(outer_dims_index.begin(), outer_dims_index.end(),
   2365                static_cast<llvm::Value*>(nullptr),
   2366                static_cast<llvm::Value*>(ir_builder_.getInt64(0)));
   2367 
   2368   if (!outer_dims.empty()) {
   2369     SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
   2370   }
   2371 
   2372   PrimitiveType primitive_type = output_shape.element_type();
   2373   unsigned primitive_type_size =
   2374       ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
   2375 
   2376   // Contiguous subregions from each operand to the concatenate contribute to a
   2377   // contiguous subregion in the target buffer starting at target_region_begin.
   2378   llvm::Value* target_region_begin = ir_builder_.CreateBitCast(
   2379       target_array.EmitArrayElementAddress(outer_dims_index, &ir_builder_,
   2380                                            "target_region"),
   2381       i8_ptr_type);
   2382   int64 byte_offset_into_target_region = 0;
   2383 
   2384   int64 inner_dims_product =
   2385       std::accumulate(inner_dims.begin(), inner_dims.end(), 1l,
   2386                       [&](int64 product, int64 inner_dim) {
   2387                         return product * output_shape.dimensions(inner_dim);
   2388                       });
   2389 
   2390   // For each operand, emit a memcpy from the operand to the target of size
   2391   // equal to the product of inner dimensions.
   2392   for (HloInstruction* operand : operands) {
   2393     const Shape& input_shape = operand->shape();
   2394     llvm_ir::IrArray source_array = GetIrArrayFor(operand);
   2395     llvm::Value* copy_source_address = ir_builder_.CreateBitCast(
   2396         source_array.EmitArrayElementAddress(outer_dims_index, &ir_builder_,
   2397                                              "src_addr"),
   2398         i8_ptr_type);
   2399 
   2400     llvm::Value* copy_target_address = ir_builder_.CreateGEP(
   2401         i8_type, target_region_begin,
   2402         ir_builder_.getInt64(byte_offset_into_target_region));
   2403 
   2404     EmitTransferElements(
   2405         copy_target_address, copy_source_address,
   2406         inner_dims_product * input_shape.dimensions(concat_dim), primitive_type,
   2407         target_array, source_array);
   2408 
   2409     byte_offset_into_target_region += inner_dims_product *
   2410                                       input_shape.dimensions(concat_dim) *
   2411                                       primitive_type_size;
   2412   }
   2413 
   2414   if (!outer_dims.empty()) {
   2415     SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
   2416   }
   2417 
   2418   return true;
   2419 }
   2420 
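         // Copies element_count contiguous elements of the given primitive type
         // from source to target: a single element uses an aligned load/store
         // pair, larger transfers use a memcpy.  Aliasing metadata from both
         // IrArrays is propagated to the emitted instructions.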
   2421 void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
   2422                                      int64 element_count,
   2423                                      PrimitiveType primitive_type,
   2424                                      const llvm_ir::IrArray& target_array,
   2425                                      const llvm_ir::IrArray& source_array) {
   2426   unsigned primitive_type_size =
   2427       ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
   2428   unsigned element_alignment = tensorflow::MathUtil::GCD<unsigned>(
   2429       primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type));
   2430   llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual(
   2431       llvm_ir::PrimitiveTypeToIrType(primitive_type, module_));
   2432 
   2433   if (element_count == 1) {
   2434     auto* load_instruction = ir_builder_.CreateAlignedLoad(
   2435         ir_builder_.CreateBitCast(source, primitive_ptr_type),
   2436         element_alignment);
   2437     source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction);
   2438     auto* store_instruction = ir_builder_.CreateAlignedStore(
   2439         load_instruction, ir_builder_.CreateBitCast(target, primitive_ptr_type),
   2440         element_alignment);
   2441     target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
   2442   } else {
   2443     auto* memcpy_instruction = ir_builder_.CreateMemCpy(
   2444         target, source, element_count * primitive_type_size, element_alignment);
   2445 
   2446     // The memcpy does the load and the store internally.  The aliasing related
   2447     // metadata has to reflect that.
   2448     std::map<int, llvm::MDNode*> merged_metadata =
   2449         llvm_ir::MergeMetadata(&module_->getContext(), source_array.metadata(),
   2450                                target_array.metadata());
   2451     for (const auto& kind_md_pair : merged_metadata) {
   2452       memcpy_instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
   2453     }
   2454   }
   2455 }
   2456 
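         // Tries the memcpy-based fast path first and falls back to the default
         // elementwise lowering if it is not applicable.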
   2457 Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
   2458   gtl::ArraySlice<HloInstruction*> operands(concatenate->operands());
   2459   string failure_reason;
   2460   TF_ASSIGN_OR_RETURN(
   2461       bool successful,
   2462       EmitFastConcatenate(concatenate, operands, &failure_reason));
   2463   if (successful) {
   2464     VLOG(1) << "Emitted fast concatenate for " << concatenate->ToString();
   2465     return Status::OK();
   2466   }
   2467 
   2468   VLOG(1) << "Could not emit fast concatenate for " << concatenate->ToString()
   2469           << ": " << failure_reason;
   2470 
   2471   return DefaultAction(concatenate);
   2472 }
   2473 
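         // Lowers a kConditional HLO to an if/then/else structure: the scalar
         // predicate is loaded and compared against zero, and the true or false
         // computation is called with its corresponding operand, writing into the
         // conditional's output buffer.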
   2474 Status IrEmitter::HandleConditional(HloInstruction* conditional) {
   2475   auto pred = conditional->operand(0);
   2476   auto true_arg = conditional->operand(1);
   2477   auto false_arg = conditional->operand(2);
   2478   TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) &&
   2479                pred->shape().element_type() == PRED)
   2480       << "Predicate on a Conditional must be bool; got: "
   2481       << ShapeUtil::HumanString(pred->shape());
   2482 
   2483   HloComputation* true_computation = conditional->true_computation();
   2484   HloComputation* false_computation = conditional->false_computation();
   2485   TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(),
   2486                                 true_computation->root_instruction()->shape()))
    2487       << "Shape of conditional should be the same as the shape of the true "
   2488       << "computation; got: " << ShapeUtil::HumanString(conditional->shape())
   2489       << " and "
   2490       << ShapeUtil::HumanString(true_computation->root_instruction()->shape());
   2491 
   2492   TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(),
   2493                                 false_computation->root_instruction()->shape()))
    2494       << "Shape of conditional should be the same as the shape of the false "
   2495       << "computation; got: " << ShapeUtil::HumanString(conditional->shape())
   2496       << " and "
   2497       << ShapeUtil::HumanString(false_computation->root_instruction()->shape());
   2498 
   2499   llvm::Function* true_function =
   2500       FindOrDie(emitted_functions_, true_computation);
   2501   llvm::Function* false_function =
   2502       FindOrDie(emitted_functions_, false_computation);
   2503 
   2504   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional));
   2505   llvm::Value* conditional_result = GetEmittedValueFor(conditional);
   2506 
   2507   // Generating:
   2508   //   if (pred)
   2509   //     cond_result = true_computation(true_operand)
   2510   //   else
   2511   //     cond_result = false_computation(false_operand)
   2512   llvm::LoadInst* pred_value = ir_builder_.CreateLoad(
   2513       GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value");
   2514   llvm::Value* pred_cond = ir_builder_.CreateICmpNE(
   2515       pred_value,
   2516       llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
   2517       "boolean_predicate");
   2518   llvm_ir::LlvmIfData if_data =
   2519       llvm_ir::EmitIfThenElse(pred_cond, "conditional", &ir_builder_);
   2520 
   2521   SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
   2522   EmitArrayFunctionCallInto(true_function, {GetEmittedValueFor(true_arg)},
   2523                             conditional_result, IrName(conditional, "_true"));
   2524 
   2525   SetToFirstInsertPoint(if_data.false_block, &ir_builder_);
   2526   EmitArrayFunctionCallInto(false_function, {GetEmittedValueFor(false_arg)},
   2527                             conditional_result, IrName(conditional, "_false"));
   2528 
   2529   SetToFirstInsertPoint(if_data.after_block, &ir_builder_);
   2530   return Status::OK();
   2531 }
   2532 
   2533 Status IrEmitter::FinishVisit(HloInstruction* root) {
   2534   // When this method is called, we should have already emitted an IR value for
   2535   // the root (return) op. The IR value holds the address of the buffer holding
   2536   // the value. If the root is a constant or parameter, we perform a memcpy from
   2537   // this buffer to the retval buffer of the computation. Otherwise, there's
   2538   // nothing to do since the result was already written directly into the output
   2539   // buffer.
   2540   VLOG(2) << "FinishVisit root: " << root->ToString();
   2541   llvm::Value* root_value = GetEmittedValueFor(root);
   2542   VLOG(2) << "  value: " << llvm_ir::DumpToString(*root_value);
   2543 
   2544   auto record_complete_computation = [&](llvm::Value* prof_counter) {
   2545     if (prof_counter) {
   2546       profiling_state_.RecordCompleteComputation(&ir_builder_, prof_counter);
   2547     }
   2548   };
   2549 
    2550   // For the parallel CPU backend, we attribute each embedded computation's
    2551   // total cycles to the kCall HLO in the entry computation that invokes it.
   2552   if (parallel_cpu_backend_ && is_top_level_computation_) {
   2553     auto* computation = root->parent();
   2554     auto* entry_computation = computation->parent()->entry_computation();
   2555     if (computation != entry_computation) {
   2556       for (HloInstruction* instruction : entry_computation->instructions()) {
   2557         if (instruction->opcode() == HloOpcode::kCall &&
   2558             instruction->to_apply()->root_instruction() == root) {
   2559           record_complete_computation(GetProfileCounterFor(*instruction));
   2560           return Status::OK();
   2561         }
   2562       }
   2563     }
   2564   }
   2565 
    2566   // For the entry computation this increment is cumulative over embedded
    2567   // computations, since it includes cycles spent in computations invoked by
    2568   // While, Call, etc.
   2569   record_complete_computation(GetProfileCounterFor(*root->parent()));
   2570   return Status::OK();
   2571 }
   2572 
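         // Returns a GEP into the profile counter array for the given HLO
         // instruction or computation, or nullptr if it is not being profiled.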
   2573 template <typename T>
   2574 llvm::Value* IrEmitter::GetProfileCounterCommon(
   2575     const T& hlo,
   2576     const std::unordered_map<const T*, int64>& profile_index_map) {
   2577   auto it = profile_index_map.find(&hlo);
   2578   if (it == profile_index_map.end()) {
   2579     return nullptr;
   2580   }
   2581 
   2582   int64 prof_counter_idx = it->second;
   2583   string counter_name = IrName("prof_counter", hlo.name());
   2584   return ir_builder_.CreateGEP(GetProfileCountersArgument(),
   2585                                ir_builder_.getInt64(prof_counter_idx),
   2586                                AsStringRef(counter_name));
   2587 }
   2588 
   2589 void IrEmitter::ProfilingState::UpdateProfileCounter(
   2590     llvm::IRBuilder<>* ir_builder, llvm::Value* prof_counter,
   2591     llvm::Value* cycle_end, llvm::Value* cycle_start) {
   2592   auto* cycle_diff = ir_builder->CreateSub(cycle_end, cycle_start);
   2593   llvm::LoadInst* old_cycle_count =
   2594       ir_builder->CreateLoad(prof_counter, "old_cycle_count");
   2595   auto* new_cycle_count =
   2596       ir_builder->CreateAdd(cycle_diff, old_cycle_count, "new_cycle_count");
   2597   ir_builder->CreateStore(new_cycle_count, prof_counter);
   2598 }
   2599 
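         // Reads the CPU cycle counter.  When rdtscp is available we call the
         // llvm.x86.rdtscp intrinsic (which, in this LLVM revision, stores the
         // TSC_AUX value through an i8* out-parameter whose lifetime is bracketed
         // with lifetime.start/end); otherwise we fall back to the generic
         // llvm.readcyclecounter intrinsic.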
   2600 llvm::Value* IrEmitter::ProfilingState::ReadCycleCounter(
   2601     llvm::IRBuilder<>* ir_builder) {
   2602   llvm::Module* module = ir_builder->GetInsertBlock()->getModule();
    2603   if (!use_rdtscp_) {
   2604     llvm::Function* func_llvm_readcyclecounter =
   2605         llvm::Intrinsic::getDeclaration(module,
   2606                                         llvm::Intrinsic::readcyclecounter);
   2607     return ir_builder->CreateCall(func_llvm_readcyclecounter);
   2608   }
   2609   llvm::Function* func_llvm_x86_rdtscp =
   2610       llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::x86_rdtscp);
   2611   if (!aux_i8ptr_) {
   2612     llvm::AllocaInst* rdtscp_aux = llvm_ir::EmitAllocaAtFunctionEntry(
   2613         ir_builder->getInt32Ty(), "rdtscp_aux", ir_builder);
   2614     aux_i8ptr_ =
   2615         ir_builder->CreateBitCast(rdtscp_aux, ir_builder->getInt8PtrTy());
   2616   }
   2617   llvm::ConstantInt* alloca_size = ir_builder->getInt64(4);
   2618   llvm::Function* func_llvm_lifetime_start =
   2619       llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::lifetime_start);
   2620   ir_builder->CreateCall(func_llvm_lifetime_start, {alloca_size, aux_i8ptr_});
   2621   llvm::Value* rdtscp_call =
   2622       ir_builder->CreateCall(func_llvm_x86_rdtscp, aux_i8ptr_);
   2623   llvm::Function* func_llvm_lifetime_end =
   2624       llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::lifetime_end);
   2625   ir_builder->CreateCall(func_llvm_lifetime_end, {alloca_size, aux_i8ptr_});
   2626   return rdtscp_call;
   2627 }
   2628 
   2629 void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* ir_builder,
   2630                                                  HloInstruction* hlo) {
   2631   auto* cycle_start = ReadCycleCounter(ir_builder);
   2632   cycle_start->setName(AsStringRef(IrName(hlo, "cycle_start")));
   2633   cycle_starts_[hlo] = cycle_start;
   2634   if (first_read_cycle_start_ == nullptr) {
   2635     first_read_cycle_start_ = cycle_start;
   2636   }
   2637 }
   2638 
   2639 void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* ir_builder,
   2640                                                  HloInstruction* hlo,
   2641                                                  llvm::Value* prof_counter) {
   2642   auto* cycle_end = ReadCycleCounter(ir_builder);
   2643   cycle_end->setName(AsStringRef(IrName(hlo, "cycle_end")));
   2644   auto* cycle_start = cycle_starts_[hlo];
   2645   UpdateProfileCounter(ir_builder, prof_counter, cycle_end, cycle_start);
   2646   last_read_cycle_end_ = cycle_end;
   2647 }
   2648 
   2649 void IrEmitter::ProfilingState::RecordCompleteComputation(
   2650     llvm::IRBuilder<>* ir_builder, llvm::Value* prof_counter) {
   2651   if (last_read_cycle_end_ && first_read_cycle_start_) {
   2652     UpdateProfileCounter(ir_builder, prof_counter, last_read_cycle_end_,
   2653                          first_read_cycle_start_);
   2654   }
   2655 }
   2656 
   2657 Status IrEmitter::Preprocess(HloInstruction* hlo) {
   2658   VLOG(3) << "Visiting: " << hlo->ToString();
   2659   if (instruction_to_profile_idx_.count(hlo)) {
   2660     profiling_state_.RecordCycleStart(&ir_builder_, hlo);
   2661   }
   2662   return Status::OK();
   2663 }
   2664 
   2665 Status IrEmitter::Postprocess(HloInstruction* hlo) {
   2666   if (auto* prof_counter = GetProfileCounterFor(*hlo)) {
   2667     profiling_state_.RecordCycleDelta(&ir_builder_, hlo, prof_counter);
   2668   }
   2669   return Status::OK();
   2670 }
   2671 
   2672 llvm_ir::IrArray IrEmitter::GetIrArrayFor(const HloInstruction* hlo) {
   2673   llvm::Value* value_for_op = GetEmittedValueFor(hlo);
   2674 
   2675   llvm_ir::IrArray array(value_for_op, hlo->shape());
   2676   AddAliasingInformationToIrArray(*hlo, &array);
   2677   return array;
   2678 }
   2679 
   2680 std::vector<llvm_ir::IrArray> IrEmitter::GetIrArraysForOperandsOf(
   2681     const HloInstruction* hlo) {
   2682   std::vector<llvm_ir::IrArray> arrays;
   2683   std::transform(
   2684       hlo->operands().begin(), hlo->operands().end(),
   2685       std::back_inserter(arrays),
   2686       [&](const HloInstruction* operand) { return GetIrArrayFor(operand); });
   2687   return arrays;
   2688 }
   2689 
   2690 llvm::Value* IrEmitter::GetEmittedValueFor(const HloInstruction* hlo) {
   2691   auto it = emitted_value_.find(hlo);
   2692   if (it == emitted_value_.end()) {
   2693     LOG(FATAL) << "could not find emitted value for: " << hlo->ToString();
   2694   }
   2695   return it->second;
   2696 }
   2697 
   2698 llvm::Type* IrEmitter::IrShapeType(const Shape& shape) {
   2699   return llvm_ir::ShapeToIrType(shape, module_);
   2700 }
   2701 
   2702 llvm::Value* IrEmitter::GetProfileCountersArgument() {
   2703   return compute_function_->profile_counters_arg();
   2704 }
   2705 
   2706 llvm::Value* IrEmitter::GetTempBuffersArgument() {
   2707   return compute_function_->temp_buffers_arg();
   2708 }
   2709 
   2710 llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() {
   2711   return compute_function_->exec_run_options_arg();
   2712 }
   2713 
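         // Returns a pointer, typed for target_shape, to the buffer backing
         // `slice`: a function-entry alloca for thread-local allocations, or a
         // (possibly offset) address loaded from the temp buffers argument
         // otherwise.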
   2714 llvm::Value* IrEmitter::EmitTempBufferPointer(
   2715     const BufferAllocation::Slice& slice, const Shape& target_shape) {
   2716   llvm::Type* element_type = IrShapeType(target_shape);
    2717   // The alignment and number of bytes within the temporary buffer are
    2718   // determined by the maximal shape chosen by buffer assignment.
   2719   const BufferAllocation& allocation = assignment_.GetAllocation(slice.index());
   2720   if (allocation.is_thread_local()) {
   2721     // Thread-local allocations should only be assigned a single buffer.
   2722     const auto& assigned_buffers = allocation.assigned_buffers();
   2723     CHECK_EQ(1, assigned_buffers.size());
   2724     const Shape& shape = assigned_buffers.begin()->first->shape();
   2725 
   2726     llvm::AllocaInst*& tempbuf_address = thread_local_buffers_[{
   2727         ir_builder_.GetInsertBlock()->getParent(), slice}];
   2728     if (tempbuf_address == nullptr) {
   2729       tempbuf_address = llvm_ir::EmitAllocaAtFunctionEntry(
   2730           IrShapeType(shape),
   2731           tensorflow::strings::StrCat("thread_local", slice.ToString()),
   2732           &ir_builder_, MinimumAlignmentForShape(target_shape));
   2733     }
   2734     return ir_builder_.CreateBitCast(tempbuf_address,
   2735                                      element_type->getPointerTo());
   2736   }
   2737 
   2738   llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
   2739       GetTempBuffersArgument(), slice.index(), &ir_builder_);
   2740   llvm::LoadInst* tempbuf_address_base =
   2741       ir_builder_.CreateLoad(tempbuf_address_ptr);
   2742   if (is_top_level_computation_ &&
   2743       hlo_module_config_.debug_options()
   2744           .xla_llvm_enable_invariant_load_metadata()) {
   2745     // In the entry computation the parameter slots in the %params argument are
   2746     // invariant through program execution.  In computations that are called
   2747     // from the entry computation (via kWhile, kCall and kConditional) the
   2748     // parameter slots are *not* invariant since they're written to by their
   2749     // callers.
   2750     tempbuf_address_base->setMetadata(
   2751         llvm::LLVMContext::MD_invariant_load,
   2752         llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{}));
   2753   }
   2754   AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size());
   2755   AttachDereferenceableMetadataForLoad(tempbuf_address_base, allocation.size());
   2756 
   2757   llvm::Value* tempbuf_address_untyped = tempbuf_address_base;
   2758   if (slice.offset() > 0) {
   2759     // Adjust the address to account for the slice offset.
   2760     tempbuf_address_untyped = ir_builder_.CreateInBoundsGEP(
   2761         tempbuf_address_base, ir_builder_.getInt64(slice.offset()));
   2762   }
   2763   return ir_builder_.CreateBitCast(tempbuf_address_untyped,
   2764                                    element_type->getPointerTo());
   2765 }
   2766 
   2767 // Emits a function call returning a single array element.  Allocates space
    2768 // for a single element_type value, and loads it after the call.
   2769 llvm::Value* IrEmitter::EmitElementFunctionCall(
   2770     llvm::Function* function, const Shape& return_shape,
   2771     gtl::ArraySlice<llvm::Value*> parameter_addresses,
   2772     tensorflow::StringPiece name) {
   2773   llvm::Value* return_value_buffer = EmitArrayFunctionCall(
   2774       function, return_shape, 1, parameter_addresses, name);
   2775   return ir_builder_.CreateLoad(
   2776       return_value_buffer,
   2777       AsStringRef(tensorflow::strings::StrCat(name, "_return_value")));
   2778 }
   2779 
   2780 // Emits a core function call based on the following pseudo-code.
   2781 //
   2782 //   char** parameter_addresses_buffer =
   2783 //       allocate buffer with a pointer for each parameter to the function
   2784 //   for each parameter index, i.e. for i = 0, ..., #parameters:
   2785 //     parameter_addresses_buffer[i] = parameter_addresses[i]
   2786 //   call function(return_value_buffer,
   2787 //                 parameter_addresses_buffer,
   2788 //                 temps)
   2789 //   return return_value_buffer  -- address of the return value.
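         //
         // Illustrative sketch only (the exact argument list is whatever
         // GetArrayFunctionCallArguments produces): for a two-parameter
         // computation "f" the emitted IR is roughly
         //
         //   %params = alloca i8*, i32 2
         //   store i8* %param0_address, i8** %params
         //   ...
         //   call void @f(<return buffer>, <run options>, i8** %params,
         //                <temps>, <profile counters>)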
   2790 void IrEmitter::EmitArrayFunctionCallInto(
   2791     llvm::Function* function, gtl::ArraySlice<llvm::Value*> parameter_addresses,
   2792     llvm::Value* return_value_buffer, tensorflow::StringPiece name) {
   2793   ir_builder_.CreateCall(
   2794       function, GetArrayFunctionCallArguments(
   2795                     parameter_addresses, &ir_builder_, name,
   2796                     /*return_value_buffer=*/return_value_buffer,
   2797                     /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
   2798                     /*temp_buffers_arg=*/GetTempBuffersArgument(),
   2799                     /*profile_counters_arg=*/GetProfileCountersArgument()));
   2800 }
   2801 
   2802 llvm::Value* IrEmitter::EmitArrayFunctionCall(
   2803     llvm::Function* function, const Shape& return_shape, int64 element_count,
   2804     gtl::ArraySlice<llvm::Value*> parameter_addresses,
   2805     tensorflow::StringPiece name) {
   2806   llvm::Value* elements =
   2807       llvm::ConstantInt::get(ir_builder_.getInt64Ty(), element_count);
   2808   PrimitiveType return_type = return_shape.element_type();
   2809   llvm::Value* return_value_buffer =
   2810       llvm_ir::EmitAllocaAtFunctionEntryWithCount(
   2811           llvm_ir::PrimitiveTypeToIrType(return_type, module_), elements,
   2812           tensorflow::strings::StrCat(name, "_return_value_address"),
   2813           &ir_builder_, MinimumAlignmentForPrimitiveType(return_type));
   2814   EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer,
   2815                             name);
   2816   return return_value_buffer;
   2817 }
   2818 
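         // Computes the address into which `op`'s result will be written and
         // records it in emitted_value_: the function's result argument for the
         // root instruction, or the op's assigned temp buffer otherwise.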
   2819 Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
   2820   llvm::Value* addr;
   2821   const Shape& target_shape = op->shape();
   2822   if (op == op->parent()->root_instruction()) {
   2823     // For the root node, we write directly to the output buffer of the
   2824     // function.
   2825     llvm::Argument* retval = compute_function_->result_arg();
   2826     if (!ShapeUtil::IsNil(target_shape)) {
   2827       llvm::AttrBuilder attr_builder;
   2828       attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
   2829       attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
   2830       retval->addAttrs(attr_builder);
   2831     }
   2832     addr = ir_builder_.CreateBitCast(retval,
   2833                                      IrShapeType(target_shape)->getPointerTo());
   2834   } else {
   2835     // For other nodes, we need the temporary buffer allocated for this node to
   2836     // write the result into.
   2837     TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
   2838                         assignment_.GetUniqueTopLevelSlice(op));
   2839     addr = EmitTempBufferPointer(slice, target_shape);
   2840   }
   2841   addr->setName(AsStringRef(IrName(op)));
   2842   emitted_value_[op] = addr;
   2843   return Status::OK();
   2844 }
   2845 
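         // Emits a loop nest that computes every element of target_op's result
         // using element_generator.  Multi-output fusions write each tuple element
         // buffer and then the tuple itself; parallel ops use dynamic loop bounds
         // read from the compute function's arguments.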
   2846 Status IrEmitter::EmitTargetElementLoop(
   2847     HloInstruction* target_op,
   2848     const llvm_ir::ElementGenerator& element_generator) {
   2849   return EmitTargetElementLoop(target_op, /*desc=*/"", element_generator);
   2850 }
   2851 
   2852 Status IrEmitter::EmitTargetElementLoop(
   2853     HloInstruction* target_op, tensorflow::StringPiece desc,
   2854     const llvm_ir::ElementGenerator& element_generator) {
   2855   VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString();
   2856 
   2857   const Shape& target_shape = target_op->shape();
   2858   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op));
   2859   llvm_ir::IrArray target_array = GetIrArrayFor(target_op);
   2860 
   2861   if (target_op->IsMultiOutputFusion()) {
    2862     // For multi-output fusion, we need to emit each operand and the root.
   2863     TF_RET_CHECK(num_dynamic_loop_bounds_ == 0);
   2864     std::vector<llvm_ir::IrArray> output_arrays;
   2865     for (int64 i = 0; i < ShapeUtil::TupleElementCount(target_shape); ++i) {
   2866       TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
   2867                           assignment_.GetUniqueSlice(target_op, {i}));
   2868       const Shape& element_shape = ShapeUtil::GetSubshape(target_shape, {i});
   2869       llvm::Value* op_target_address =
   2870           EmitTempBufferPointer(slice, element_shape);
   2871       output_arrays.push_back(
   2872           llvm_ir::IrArray(op_target_address, element_shape));
   2873     }
   2874     TF_RETURN_IF_ERROR(
   2875         llvm_ir::LoopEmitter(element_generator, output_arrays, &ir_builder_)
   2876             .EmitLoop(IrName(target_op)));
   2877 
   2878     std::vector<llvm::Value*> tuple_operand_ptrs;
   2879     for (int64 i = 0; i < output_arrays.size(); ++i) {
   2880       tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
   2881     }
   2882     llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &ir_builder_, module_);
   2883 
   2884   } else {
   2885     if (ShouldEmitParallelLoopFor(*target_op)) {
    2886       // Emit code to read dynamic loop bounds from the compute function argument.
   2887       std::vector<std::pair<llvm::Value*, llvm::Value*>> dynamic_loop_bounds =
   2888           compute_function_->GetDynamicLoopBounds();
   2889       // Emit parallel loop with dynamic loop bounds for most-major dimensions.
   2890       TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array,
   2891                                              &dynamic_loop_bounds, &ir_builder_)
   2892                              .EmitLoop(IrName(target_op)));
   2893     } else {
   2894       TF_RETURN_IF_ERROR(
   2895           llvm_ir::LoopEmitter(element_generator, target_array, &ir_builder_)
   2896               .EmitLoop(IrName(target_op)));
   2897     }
   2898   }
   2899   return Status::OK();
   2900 }
   2901 
   2902 Status IrEmitter::EmitMemcpy(const HloInstruction& source,
   2903                              const HloInstruction& destination) {
   2904   llvm::Value* source_value = GetEmittedValueFor(&source);
   2905   llvm::Value* destination_value = GetEmittedValueFor(&destination);
   2906   int64 source_size = ByteSizeOf(source.shape());
   2907   // TODO(b/63762267): Be more aggressive about specifying alignment.
   2908   ir_builder_.CreateMemCpy(destination_value, source_value, source_size, 1);
   2909   return Status::OK();
   2910 }
   2911 
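         // Verifies that all operands share the same element type and that this
         // type is one of supported_types; otherwise returns Unimplemented.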
   2912 Status IrEmitter::ElementTypesSameAndSupported(
   2913     const HloInstruction& instruction,
   2914     gtl::ArraySlice<const HloInstruction*> operands,
   2915     gtl::ArraySlice<PrimitiveType> supported_types) {
   2916   for (auto operand : operands) {
   2917     TF_RET_CHECK(
   2918         ShapeUtil::SameElementType(operands[0]->shape(), operand->shape()));
   2919   }
   2920 
   2921   TF_RET_CHECK(!operands.empty());
   2922   PrimitiveType primitive_type = operands[0]->shape().element_type();
   2923   if (std::find(supported_types.begin(), supported_types.end(),
   2924                 primitive_type) == supported_types.end()) {
   2925     return Unimplemented("unsupported operand type %s in op %s",
   2926                          PrimitiveType_Name(primitive_type).c_str(),
   2927                          HloOpcodeString(instruction.opcode()).c_str());
   2928   }
   2929   return Status::OK();
   2930 }
   2931 
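         // Fallback lowering: each operand is read elementwise through its IrArray
         // and the elemental IR emitter produces the body of a target element loop
         // for the op.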
   2932 Status IrEmitter::DefaultAction(HloInstruction* hlo) {
   2933   ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
   2934   for (const HloInstruction* operand : hlo->operands()) {
   2935     operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
   2936       return GetIrArrayFor(operand).EmitReadArrayElement(index, &ir_builder_);
   2937     };
   2938   }
   2939   CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
   2940   return EmitTargetElementLoop(
   2941       hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
   2942 }
   2943 
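         // Calls `computation` with scalar arguments by storing each argument to a
         // function-entry alloca and invoking the scalar (single-element) call
         // form; returns the loaded scalar result.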
   2944 StatusOr<llvm::Value*> IrEmitter::EmitScalarCall(
   2945     PrimitiveType return_type, HloComputation* computation,
   2946     const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name) {
   2947   llvm::Function* llvm_function = FindOrDie(emitted_functions_, computation);
   2948   std::vector<llvm::Value*> argument_addrs;
   2949   for (auto argument : arguments) {
   2950     llvm::Value* argument_addr = llvm_ir::EmitAllocaAtFunctionEntry(
   2951         argument->getType(), "arg_addr", &ir_builder_);
   2952     ir_builder_.CreateStore(argument, argument_addr);
   2953     argument_addrs.push_back(argument_addr);
   2954   }
   2955   return EmitElementFunctionCall(llvm_function,
   2956                                  ShapeUtil::MakeShape(return_type, {}),
   2957                                  argument_addrs, name);
   2958 }
   2959 }  // namespace cpu
   2960 }  // namespace xla
   2961