/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"

#include <stddef.h>
#include <stdint.h>
#include <algorithm>
#include <iterator>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/ir_function.h"
#include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

namespace {
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;
}  // namespace

namespace cpu {

IrEmitter::IrEmitter(
    const HloModule& hlo_module, const BufferAssignment& assignment,
    llvm::Module* llvm_module,
    std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx,
    std::unordered_map<const HloComputation*, int64> computation_to_profile_idx,
    const TargetMachineFeatures* target_machine_features,
    bool emit_code_for_msan)
    : assignment_(assignment),
      module_(llvm_module),
      arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()),
      b_(llvm_module->getContext()),
      instruction_to_profile_idx_(std::move(instruction_to_profile_idx)),
      computation_to_profile_idx_(std::move(computation_to_profile_idx)),
      alias_analysis_(hlo_module, assignment, &llvm_module->getContext()),
      hlo_module_config_(hlo_module.config()),
      is_top_level_computation_(false),
      target_machine_features_(*target_machine_features),
      emit_code_for_msan_(emit_code_for_msan) {
  b_.setFastMathFlags(llvm_ir::GetCpuFastMathFlags(hlo_module_config_));
  Status s = GatherComputationsByAllocationType(
      &hlo_module, &thread_local_computations_, &global_computations_);
  absl::c_sort(thread_local_computations_);
  absl::c_sort(global_computations_);
  TF_CHECK_OK(s) << "Should have failed buffer assignment.";
}

StatusOr<llvm::Function*> IrEmitter::EmitComputation(
    HloComputation* computation, const string& function_name_prefix,
    bool is_top_level_computation,
    absl::Span<HloInstruction* const> instruction_order) {
  string function_name = name_uniquer_.GetUniqueName(function_name_prefix);
  VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]";
  is_top_level_computation_ = is_top_level_computation;
  num_dynamic_loop_bounds_ = 0;
  if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
    num_dynamic_loop_bounds_ =
        computation->root_instruction()->outer_dimension_partitions().size();
  }

  if (computation->root_instruction()->opcode() != HloOpcode::kOutfeed) {
    TF_ASSIGN_OR_RETURN(
        computation_root_allocation_,
        assignment_.GetUniqueTopLevelSlice(computation->root_instruction()));
  }

  for (const HloInstruction* param : computation->parameter_instructions()) {
    TF_ASSIGN_OR_RETURN(BufferAllocation::Slice param_slice,
                        assignment_.GetUniqueTopLevelSlice(param));
    computation_parameter_allocations_[param_slice.allocation()->index()] =
        param->parameter_number();
  }

  InitializeIrFunction(function_name);
  // The rdtscp instruction is x86 specific. We will fall back to LLVM's
  // generic readcyclecounter if it is unavailable.
  bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 ||
                    arch_type_ == llvm::Triple::ArchType::x86_64;
  profiling_state_ = ProfilingState(use_rdtscp);
  TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, instruction_order));
  llvm::Function* ir_function = compute_function_->function();
  InsertOrDie(&emitted_functions_, computation, ir_function);
  // Delete 'compute_function', finalizing 'ir_function' and restoring caller
  // IR insert point.
  compute_function_.reset();
  computation_root_allocation_ = BufferAllocation::Slice();
  computation_parameter_allocations_.clear();
  return ir_function;
}

void IrEmitter::InitializeIrFunction(const string& function_name) {
  // Functions with local linkage get an inlining bonus. Because we know
  // a priori that embedded functions (non-entry functions) will not have
  // their names resolved, give them local linkage.
  llvm::Function::LinkageTypes linkage =
      is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage
                                : llvm::GlobalValue::InternalLinkage;
  // Create and initialize new IrFunction.
  compute_function_.reset(new IrFunction(function_name, linkage,
                                         hlo_module_config_, module_, &b_,
                                         num_dynamic_loop_bounds_));
}

IrEmitter::~IrEmitter() {}

Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
  VLOG(2) << "HandleBitcast: " << bitcast->ToString();
  emitted_value_[bitcast] =
      BitCast(GetEmittedValueFor(bitcast->operand(0)),
              IrShapeType(bitcast->shape())->getPointerTo(), IrName(bitcast));
  return Status::OK();
}

llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) {
  llvm::Constant* initializer =
      llvm_ir::ConvertLiteralToIrConstant(literal, module_);
  llvm::GlobalVariable* result_global = new llvm::GlobalVariable(
      /*Module=*/*module_,
      /*Type=*/initializer->getType(),
      /*isConstant=*/true,
      /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
      /*Initializer=*/initializer,
      /*Name=*/"");
  result_global->setAlignment(MinimumAlignmentForShape(literal.shape()));
  result_global->setUnnamedAddr(llvm::GlobalVariable::UnnamedAddr::Global);
  return llvm::ConstantExpr::getBitCast(
      result_global, IrShapeType(literal.shape())->getPointerTo());
}

Status IrEmitter::EmitConstantGlobals() {
  for (const BufferAllocation& allocation : assignment_.Allocations()) {
    if (!allocation.is_constant()) {
      continue;
    }

    const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation);
    llvm::Constant* global_for_const;
    auto it = emitted_literals_.find(&literal);
    if (it != emitted_literals_.end()) {
      global_for_const = it->second;
    } else {
      global_for_const = EmitGlobalForLiteral(literal);
      InsertOrDie(&emitted_literals_, &literal, global_for_const);
    }

    InsertOrDie(&constant_buffer_to_global_, allocation.index(),
                global_for_const);
  }

  return Status::OK();
}

Status IrEmitter::HandleConstant(HloInstruction* constant) {
  VLOG(2) << "HandleConstant: " << constant->ToString();
  // IrEmitter::EmitConstantGlobals has already taken care of emitting the body
  // of the constant.
  return EmitTargetAddressForOp(constant);
}

Status IrEmitter::HandleCopy(HloInstruction* copy) {
  if (copy->shape().IsTuple()) {
    // kCopy shallow copies a tuple so just memcpy the top-level buffer.
    TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy));
    return EmitMemcpy(*(copy->operand(0)), *copy);
  } else if (copy->shape().IsArray()) {
    // Use the elemental emitter for array shapes.
    return DefaultAction(copy);
  }
  return Unimplemented("unsupported operand type %s for copy instruction",
                       PrimitiveType_Name(copy->shape().element_type()));
}

// Calculate the alignment of a buffer allocated for a given primitive type.
int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) {
  int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
  DCHECK_GE(byte_size, 0);
  // Largest scalar is a complex128 so we don't need to worry about the
  // int64->int truncation here.
  DCHECK_LE(byte_size, 16);

  // Allocations may be 8-byte aligned if part of a small block.
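  // For example, an F32 element (4 bytes) keeps its natural 4-byte alignment,
  // while F64, C64, and C128 elements (8 or 16 bytes) are all clamped down to
  // 8-byte alignment by the min below.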
  return std::min(8LL, byte_size);
}

int64 IrEmitter::ByteSizeOf(const Shape& shape) const {
  return llvm_ir::ByteSizeOf(shape, module_->getDataLayout());
}

// Calculate the alignment of a buffer allocated for a given shape.
int IrEmitter::MinimumAlignmentForShape(const Shape& shape) {
  if (ShapeUtil::IsScalar(shape)) {
    return MinimumAlignmentForPrimitiveType(shape.element_type());
  }

  int64 buffer_size = ByteSizeOf(shape);
  DCHECK_GE(buffer_size, 0);
  DCHECK_LE(buffer_size, SIZE_MAX);

  return target_machine_features_.minimum_alignment_for_allocation(buffer_size);
}

void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
                                               const Shape& shape) {
  int alignment = MinimumAlignmentForShape(shape);
  if (alignment > 1) {
    llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
  }
}

void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
                                               int64 buffer_size) {
  int alignment =
      target_machine_features_.minimum_alignment_for_allocation(buffer_size);
  if (alignment > 1) {
    llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
  }
}

void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                                     const Shape& shape) {
  AttachDereferenceableMetadataForLoad(load, ByteSizeOf(shape));
}

void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                                     int64 buffer_size) {
  if (buffer_size > 0) {
    llvm_ir::SetDereferenceableMetadataForLoad(load, buffer_size);
  }
}

Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
  // A tuple is an array of pointers, one for each operand. Each pointer points
  // to the output buffer of its corresponding operand. A GetTupleElement
  // instruction forwards a pointer to the tuple element buffer at the given
  // index.
  auto operand = get_tuple_element->operand(0);
  const Shape& shape = get_tuple_element->shape();
  emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
      shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape),
      GetEmittedValueFor(operand), &b_);
  return Status::OK();
}

Status IrEmitter::HandleSelect(HloInstruction* select) {
  auto pred = select->operand(0);
  TF_RET_CHECK(pred->shape().element_type() == PRED);
  return DefaultAction(select);
}

Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
  auto pred = tuple_select->operand(0);
  auto on_true = tuple_select->operand(1);
  auto on_false = tuple_select->operand(2);
  TF_RET_CHECK(pred->shape().element_type() == PRED);
  TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()));
  TF_RET_CHECK(tuple_select->shape().IsTuple());
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple_select));
  llvm_ir::EmitTupleSelect(GetIrArrayFor(tuple_select), GetIrArrayFor(pred),
                           GetEmittedValueFor(on_true),
                           GetEmittedValueFor(on_false), &b_);
  return Status::OK();
}

Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
  HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
  VLOG(2) << "HandleInfeed: " << infeed->ToString();

  // The infeed operation produces a two-element tuple containing data and a
  // token value. HloInfeedInstruction::infeed_shape gives us the data shape.
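  // For example, an infeed producing (f32[16,3], token[]) has an infeed_shape
  // of f32[16,3]; the token occupies tuple element {1}.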
  const Shape& data_shape = infeed->infeed_shape();
  DCHECK(ShapeUtil::Equal(data_shape,
                          ShapeUtil::GetTupleElementShape(infeed->shape(), 0)));
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(infeed));

  // Write the tuple index table.
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
                      assignment_.GetUniqueSlice(infeed, {0}));
  llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape);
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice,
                      assignment_.GetUniqueSlice(infeed, {1}));
  llvm::Value* token_address = EmitBufferPointer(
      token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1));
  llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_);

  if (data_shape.IsTuple()) {
    TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape));

    // For a tuple, we first copy each of the internal elements to
    // their corresponding target locations. We then construct the
    // tuple outer buffer containing pointers to the internal
    // elements.
    std::vector<llvm::Value*> tuple_element_addresses;
    for (int64 i = 0; i < data_shape.tuple_shapes_size(); ++i) {
      TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer,
                          assignment_.GetUniqueSlice(infeed, {0, i}));

      const Shape& tuple_element_shape =
          ShapeUtil::GetTupleElementShape(data_shape, i);

      // Only the outer tuple buffer's target address is obtained from
      // GetEmittedValueFor, to handle the case when Infeed is the root
      // instruction. Target addresses for internal elements can be obtained
      // from EmitBufferPointer.
      llvm::Value* tuple_element_address =
          EmitBufferPointer(buffer, tuple_element_shape);

      TF_RETURN_IF_ERROR(EmitXfeedTransfer(
          XfeedKind::kInfeed, tuple_element_shape, tuple_element_address));

      tuple_element_addresses.push_back(tuple_element_address);
    }

    llvm_ir::EmitTuple(llvm_ir::IrArray(data_address, data_shape),
                       tuple_element_addresses, &b_);
  } else {
    TF_RETURN_IF_ERROR(
        EmitXfeedTransfer(XfeedKind::kInfeed, data_shape, data_address));
  }

  return Status::OK();
}

Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
                                    llvm::Value* program_buffer_address) {
  int64 length = ByteSizeOf(shape);
  if (length <= 0 || length > std::numeric_limits<int32>::max()) {
    return InvalidArgument(
        "xfeed (infeed or outfeed) buffer length %d is outside the valid "
        "size range",
        length);
  }
  int32 length_32 = static_cast<int32>(length);

  int32 shape_length;
  TF_ASSIGN_OR_RETURN(
      llvm::Value * shape_ptr,
      llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_));

  llvm::Type* int32_type = b_.getInt32Ty();
  llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
  llvm::FunctionType* acquire_type = llvm::FunctionType::get(
      i8_ptr_type,
      {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type,
       /*shape_ptr*/ i8_ptr_type, /*shape_length*/ int32_type},
      /*isVarArg=*/false);

  llvm::Function* acquire_func;
  if (kind == XfeedKind::kInfeed) {
    acquire_func = llvm::dyn_cast<llvm::Function>(
        module_
            ->getOrInsertFunction(
                runtime::kAcquireInfeedBufferForDequeueSymbolName, acquire_type)
            .getCallee());
  } else {
    acquire_func = llvm::dyn_cast<llvm::Function>(
        module_
            ->getOrInsertFunction(
                runtime::kAcquireOutfeedBufferForPopulationSymbolName,
                acquire_type)
            .getCallee());
  }
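  // Note: the acquire/release functions referenced here are provided by the
  // XLA CPU runtime; the JIT resolves these symbol names when the generated
  // code is linked.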
  acquire_func->setCallingConv(llvm::CallingConv::C);

  llvm::FunctionType* release_type = llvm::FunctionType::get(
      b_.getVoidTy(),
      {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type,
       /*buffer_ptr*/ i8_ptr_type, /*shape_ptr*/ i8_ptr_type,
       /*shape_length*/ int32_type},
      /*isVarArg=*/false);

  llvm::Function* release_func;
  if (kind == XfeedKind::kInfeed) {
    release_func = llvm::dyn_cast<llvm::Function>(
        module_
            ->getOrInsertFunction(
                runtime::kReleaseInfeedBufferAfterDequeueSymbolName,
                release_type)
            .getCallee());
  } else {
    release_func = llvm::dyn_cast<llvm::Function>(
        module_
            ->getOrInsertFunction(
                runtime::kReleaseOutfeedBufferAfterPopulationSymbolName,
                release_type)
            .getCallee());
  }
  release_func->setCallingConv(llvm::CallingConv::C);

  // Implementation note: this call informs the runtime that it wants a buffer
  // of size exactly 'length_32', and the runtime is responsible for
  // check-failing the process if there is a mismatch, versus passing us back a
  // buffer that we might overrun.
  llvm::Value* acquired_pointer = Call(
      acquire_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
                     shape_ptr, b_.getInt32(shape_length)});

  if (kind == XfeedKind::kInfeed) {
    // Copy to the program buffer address from the acquired buffer.
    MemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer,
           /*SrcAlign=*/1, length_32);
  } else {
    // Outfeed -- copy from the in-program address to the acquired buffer.
    MemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address,
           /*SrcAlign=*/1, length_32);
  }

  Call(release_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
                      acquired_pointer, shape_ptr, b_.getInt32(shape_length)});

  return Status::OK();
}

Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
  // Outfeed produces no useful result, but it does return a token[] that can
  // be threaded through to other side-effecting operations to ensure ordering.
  // In the IR emitter we treat this token as a normal u8[] and thus need to
  // insert an entry for it in emitted_value_.
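  //
  // A non-tuple operand is transferred with a single EmitXfeedTransfer call
  // below; a (non-nested) tuple operand is transferred element by element.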
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(outfeed));

  HloInstruction* operand = outfeed->operands()[0];
  const Shape& operand_shape = operand->shape();

  llvm::Value* value = GetEmittedValueFor(operand);
  if (!operand_shape.IsTuple()) {
    return EmitXfeedTransfer(XfeedKind::kOutfeed, operand_shape, value);
  }

  TF_RET_CHECK(!ShapeUtil::IsNestedTuple(operand_shape));

  for (int64 i = 0; i < operand_shape.tuple_shapes_size(); ++i) {
    const Shape& tuple_element_shape =
        ShapeUtil::GetTupleElementShape(operand_shape, i);
    llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement(
        tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape),
        value, &b_);
    TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed,
                                         tuple_element_shape, tuple_element));
  }

  return Status::OK();
}

Status IrEmitter::HandleSort(HloInstruction* hlo) {
  const HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort));
  Shape keys_shape = sort->keys()->shape();
  PrimitiveType keys_type = keys_shape.element_type();
  switch (keys_type) {
    case PRED:
    case S8:
    case U8:
    case S16:
    case U16:
    case BF16:
    case F16:
    case S32:
    case U32:
    case F32:
    case S64:
    case U64:
    case F64:
      break;
    default:
      return Unimplemented(
          "Element type %s not supported in the Sort op on CPU.",
          PrimitiveType_Name(keys_type));
  }
  std::vector<llvm::Value*> destination_addresses(sort->operand_count());
  for (int64 i = 0; i < sort->operand_count(); ++i) {
    ShapeIndex shape_index =
        sort->values_count() > 0 ? ShapeIndex({i}) : ShapeIndex({});
    const HloInstruction* operand = sort->operand(i);
    // We assume that the layout of all involved operands and outputs is the
    // same.
    TF_RET_CHECK(
        LayoutUtil::LayoutsInShapesEqual(keys_shape, operand->shape()));
    TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
        keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index)));

    // The sort is implemented in-place, therefore we first copy the operand
    // buffer to the output buffer if they are not the same.
    auto destination_buffer = GetAllocationSlice(*sort, shape_index);
    destination_addresses[i] =
        EmitBufferPointer(destination_buffer, operand->shape());
    auto source_address = GetAllocationSlice(*operand);
    if (destination_buffer != source_address) {
      int64 primitive_type_size =
          ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type());
      auto source_buffer = GetEmittedValueFor(operand);
      int64 size = ByteSizeOf(operand->shape());
      MemCpy(destination_addresses[i], /*DstAlign=*/primitive_type_size,
             source_buffer,
             /*SrcAlign=*/primitive_type_size, size);
    }
  }

  // Normalize the shape and the dimension to sort.
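  // For illustration only: keys of shape f32[3,4] with layout {0,1} normalize
  // to the descending-layout shape f32[4,3], and sorting along logical
  // dimension 0 then corresponds to sorting along physical dimension 1 of
  // that normalized shape.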
  Shape normalized_keys_shape =
      ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(keys_shape);
  int64 physical_dimension_to_sort = LayoutUtil::MakeLogicalToPhysical(
      keys_shape.layout())[sort->sort_dimension()];

  int64 sort_dimension_elements =
      normalized_keys_shape.dimensions(physical_dimension_to_sort);
  int64 higher_dimensions = 1;
  for (int64 i = 0; i < physical_dimension_to_sort; ++i) {
    higher_dimensions *= normalized_keys_shape.dimensions(i);
  }
  int64 lower_dimensions = 1;
  for (int64 i = normalized_keys_shape.rank() - 1;
       i > physical_dimension_to_sort; --i) {
    lower_dimensions *= normalized_keys_shape.dimensions(i);
  }

  auto less_than_function = FindOrDie(emitted_functions_, sort->to_apply());
  CHECK(absl::c_binary_search(thread_local_computations_, sort->to_apply()));
  llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get(
      b_.getVoidTy(),
      {b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(),
       b_.getInt8PtrTy()->getPointerTo(), b_.getInt32Ty(),
       b_.getInt32Ty()->getPointerTo(), b_.getInt1Ty(), b_.getInt8PtrTy(),
       b_.getInt64Ty()->getPointerTo(), less_than_function->getType()},
      /*isVarArg=*/false);
  auto* key_value_sort_func = llvm::dyn_cast<llvm::Function>(
      module_
          ->getOrInsertFunction(runtime::kKeyValueSortSymbolName,
                                key_value_sort_type)
          .getCallee());
  key_value_sort_func->setCallingConv(llvm::CallingConv::C);
  key_value_sort_func->setDoesNotThrow();
  llvm::Value* values = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
      b_.getInt8PtrTy(), b_.getInt32(sort->operand_count()), "cc_values_alloca",
      &b_);
  llvm::Value* sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
      b_.getInt32Ty(), b_.getInt32(sort->operand_count()), "cc_sizes_alloca",
      &b_);
  for (int64 i = 0; i < sort->operand_count(); ++i) {
    llvm::Value* value_as_i8ptr =
        PointerCast(destination_addresses[i], b_.getInt8PtrTy());
    llvm::Value* slot_in_values_alloca =
        ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i);
    Store(value_as_i8ptr, slot_in_values_alloca);
    llvm::Value* slot_in_sizes_alloca =
        ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i);
    llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType(
        sort->operand(i)->shape().element_type()));
    Store(size, slot_in_sizes_alloca);
  }

  Call(key_value_sort_func,
       {b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
        b_.getInt64(lower_dimensions), values,
        b_.getInt32(sort->operand_count()), sizes,
        b_.getInt1(sort->is_stable()), GetExecutableRunOptionsArgument(),
        GetProfileCountersArgument(), less_than_function});

  if (sort->values_count() > 0) {
    llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_);
  }
  return Status::OK();
}

Status IrEmitter::HandleTuple(HloInstruction* tuple) {
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple));
  std::vector<llvm::Value*> base_ptrs;
  for (auto operand : tuple->operands()) {
    base_ptrs.push_back(GetEmittedValueFor(operand));
  }
  llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &b_);
  return Status::OK();
}

llvm::Value* IrEmitter::EmitElementalMap(
    const HloMapInstruction& map_instr,
    absl::Span<llvm::Value* const> elemental_operands, absl::string_view name) {
  return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name);
}

StatusOr<llvm::Value*> IrEmitter::EmitElementalReduceWindow(
    const HloReduceWindowInstruction* reduce_window,
    const llvm_ir::ElementGenerator& input_generator,
    const llvm_ir::IrArray::Index& index) {
  const HloInstruction* operand = reduce_window->operand(0);
  const Window& window = reduce_window->window();

  // We fold inputs into the accumulator and initialize it to
  // the initial value on the reduce_window.
  PrimitiveType operand_element_type = operand->shape().element_type();
  llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry(
      llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
      "reduce_window_accumulator_address", &b_,
      MinimumAlignmentForPrimitiveType(operand_element_type));
  Store(Load(GetEmittedValueFor(reduce_window->operand(1))),
        accumulator_address);

  llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &b_);
  std::vector<int64> window_size;
  for (const auto& dim : window.dimensions()) {
    window_size.push_back(dim.size());
  }
  const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape(
      ShapeUtil::MakeShape(operand_element_type, window_size), "window");
  CHECK_EQ(window_index.size(), index.size());

  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);

  std::vector<llvm::Value*> input_multi_index(index.size());
  llvm::Value* in_bounds_condition = nullptr;
  for (size_t i = 0; i < index.size(); ++i) {
    llvm::Value* strided_index =
        NSWMul(index[i], b_.getInt64(window.dimensions(i).stride()));
    input_multi_index[i] = NSWSub(
        NSWAdd(strided_index,
               NSWMul(window_index[i],
                      b_.getInt64(window.dimensions(i).window_dilation()))),
        b_.getInt64(window.dimensions(i).padding_low()));

    // We need to verify that we are not in the dilated base area.
    llvm::Value* dilation_condition =
        ICmpEQ(SRem(input_multi_index[i],
                    b_.getInt64(window.dimensions(i).base_dilation())),
               b_.getInt64(0));
    if (in_bounds_condition == nullptr) {
      in_bounds_condition = dilation_condition;
    } else {
      in_bounds_condition = And(in_bounds_condition, dilation_condition);
    }

    // Apply base dilation to the index.
    input_multi_index[i] =
        SDiv(input_multi_index[i],
             b_.getInt64(window.dimensions(i).base_dilation()));

    // We need to check if 0 <= input_multi_index[i] < bound, as otherwise we
    // are in the padding and can skip the computation. That is equivalent to
    // input_multi_index[i] < bound as an *unsigned* comparison, since a
    // negative value will wrap to a large positive value.
    llvm::Value* index_condition =
        ICmpULT(input_multi_index[i],
                b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
    if (in_bounds_condition == nullptr) {
      in_bounds_condition = index_condition;
    } else {
      in_bounds_condition = And(in_bounds_condition, index_condition);
    }
  }
  CHECK(in_bounds_condition != nullptr);

  llvm_ir::LlvmIfData if_data =
      llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
  SetToFirstInsertPoint(if_data.true_block, &b_);

  // We are not in the padding, so carry out the computation.
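  // The reducer computation is invoked as a thread-local call with the
  // running accumulator and the freshly generated input element; its result
  // becomes the new accumulator value.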
  llvm_ir::IrArray::Index input_index(input_multi_index, operand->shape(),
                                      b_.getInt64Ty());
  TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
                      input_generator(input_index));
  llvm::Value* result = EmitThreadLocalCall(
      *reduce_window->to_apply(), {Load(accumulator_address), input_value},
      "reducer_function");
  Store(result, accumulator_address);

  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
  return Load(accumulator_address);
}

Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
  // Pseudo code for reduce window:
  //
  //   for (coordinates O in the output)
  //     value = init_value;
  //     for (coordinates W in the window)
  //       for each index i:
  //         input coordinates I_i = O_i * stride_i + W_i - pad_low_i
  //       if I within bounds of input:
  //         value = function(value, input(I));
  //     output(O) = value;
  //
  // This is completely un-optimized and just here to have something
  // that works.
  return DefaultAction(reduce_window);
}

Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
  CHECK_EQ(select_and_scatter->operand_count(), 3);
  const auto operand = select_and_scatter->operand(0);
  const auto source = select_and_scatter->operand(1);
  const auto init_value = select_and_scatter->operand(2);
  const Window& window = select_and_scatter->window();
  PrimitiveType operand_element_type = operand->shape().element_type();
  const int64 rank = operand->shape().rank();
  CHECK_EQ(rank, source->shape().rank());
  CHECK_EQ(rank, window.dimensions_size());

  // TODO(b/31410564): Implement dilation for select-and-scatter.
  if (window_util::HasDilation(window)) {
    return Unimplemented(
        "Dilation for SelectAndScatter is not implemented on CPU. ");
  }

  // Pseudo code for select-and-scatter:
  //
  // initialized_flag is initially off for every window, and is turned on after
  // the first iteration is completed and the first operand value is selected.
  //
  // output(*) = init_value
  // for (coordinates S in the source) {
  //   initialized_flag = false
  //   for (coordinates W in the window) {
  //     I = S * stride + W - pad_low
  //     if I within bounds of operand:
  //       if !initialized_flag or select(selected_value, operand(I)) == false:
  //         selected_value = operand(I)
  //         selected_index = I
  //         initialized_flag = true
  //   }
  //   output(selected_index) = scatter(output(selected_index), source(S))
  // }
  //

  // Initialize the output array with the given init_value.
  TF_RETURN_IF_ERROR(EmitTargetElementLoop(
      select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"),
      [this, init_value](const llvm_ir::IrArray::Index& target_index) {
        llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
        return Load(init_value_addr);
      }));

  // Create a loop to iterate over the source array to scatter to the output.
  llvm_ir::ForLoopNest source_loops(IrName(select_and_scatter), &b_);
  const llvm_ir::IrArray::Index source_index =
      source_loops.AddLoopsForShape(source->shape(), "source");
  SetToFirstInsertPoint(source_loops.GetInnerLoopBodyBasicBlock(), &b_);

  // Allocate space to keep the currently selected value, its index, and
  // the boolean initialized_flag, which is initially set to false.
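  // These allocas are emitted in the function's entry block, so the buffers
  // are created once per function invocation rather than once per source-loop
  // iteration.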
  llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
      llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
      "selected_value_address", &b_,
      MinimumAlignmentForPrimitiveType(operand_element_type));
  llvm::Value* selected_index_address =
      llvm_ir::EmitAllocaAtFunctionEntryWithCount(
          b_.getInt64Ty(), b_.getInt32(rank), "selected_index_address", &b_);
  llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
      b_.getInt1Ty(), "initialized_flag_address", &b_);
  Store(b_.getInt1(false), initialized_flag_address);

  // Create the inner loop to iterate over the window.
  llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), &b_);
  std::vector<int64> window_size;
  for (const auto& dim : window.dimensions()) {
    window_size.push_back(dim.size());
  }
  const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape(
      ShapeUtil::MakeShape(operand_element_type, window_size), "window");
  SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(), &b_);

  // Compute the operand index to visit and evaluate whether that index is
  // within the operand's bounds. The unsigned comparison includes checking
  // whether the operand index >= 0.
  std::vector<llvm::Value*> operand_multi_index(source_index.size());
  llvm::Value* in_bounds_condition = b_.getTrue();
  for (int64 i = 0; i < rank; ++i) {
    llvm::Value* strided_index =
        NSWMul(source_index[i], b_.getInt64(window.dimensions(i).stride()));
    operand_multi_index[i] =
        NSWSub(NSWAdd(strided_index, window_index[i]),
               b_.getInt64(window.dimensions(i).padding_low()));
    llvm::Value* index_condition =
        ICmpULT(operand_multi_index[i],
                b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
    in_bounds_condition = And(in_bounds_condition, index_condition);
  }
  CHECK(in_bounds_condition != nullptr);

  // Only need to do something if the operand index is within the bounds. First
  // check if the initialized_flag is set.
  llvm_ir::LlvmIfData if_in_bounds =
      llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
  SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
  llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
      Load(initialized_flag_address), "initialized", &b_);

  // If the initialized_flag is false, initialize the selected value and index
  // with the operand element currently being visited.
  SetToFirstInsertPoint(if_initialized.false_block, &b_);
  const auto save_operand_index =
      [&](const llvm_ir::IrArray::Index& operand_index) {
        for (int64 i = 0; i < rank; ++i) {
          llvm::Value* selected_index_address_slot =
              InBoundsGEP(selected_index_address, {b_.getInt32(i)});
          Store(operand_index[i], selected_index_address_slot);
        }
      };
  llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
  llvm_ir::IrArray::Index operand_index(
      operand_multi_index, operand_array.GetShape(), b_.getInt64Ty());
  llvm::Value* operand_data =
      operand_array.EmitReadArrayElement(operand_index, &b_);
  Store(operand_data, selected_value_address);
  save_operand_index(operand_index);
  Store(b_.getInt1(true), initialized_flag_address);

  // If the initialized_flag is true, call the `select` function to potentially
  // update the selected value and index with the operand element currently
  // being visited.
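  // For example, with a GE-style `select` computation (as in max pooling),
  // this keeps the largest operand value seen so far in the window
  // (illustrative only; the `select` computation comes from the HLO).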
  SetToFirstInsertPoint(if_initialized.true_block, &b_);
  llvm::Value* operand_address =
      operand_array.EmitArrayElementAddress(operand_index, &b_);
  llvm::Value* operand_element = Load(operand_address);
  llvm::Value* result = EmitThreadLocalCall(
      *select_and_scatter->select(),
      {Load(selected_value_address), operand_element}, "select_function");

  // If the 'select' function returns false, update the selected value and the
  // index to the operand element currently being visited.
  llvm::Value* cond = ICmpNE(
      result,
      llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
      "boolean_predicate");
  llvm_ir::LlvmIfData if_select_lhs =
      llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
  SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
  Store(Load(operand_address), selected_value_address);
  save_operand_index(operand_index);

  // After iterating over the window elements, scatter the source element to
  // the selected index of the output. The value we store at the output
  // location is computed by calling the `scatter` function with the source
  // value and the current output value.
  SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), &b_);
  std::vector<llvm::Value*> selected_multi_index;
  for (int64 i = 0; i < rank; ++i) {
    llvm::Value* selected_index_address_slot =
        InBoundsGEP(selected_index_address, {b_.getInt32(i)});
    selected_multi_index.push_back(Load(selected_index_address_slot));
  }
  llvm_ir::IrArray source_array(GetIrArrayFor(source));
  llvm::Value* source_value =
      source_array.EmitReadArrayElement(source_index, &b_);
  llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter));
  llvm_ir::IrArray::Index selected_index(
      selected_multi_index, output_array.GetShape(), source_index.GetType());
  llvm::Value* output_value =
      output_array.EmitReadArrayElement(selected_index, &b_);
  llvm::Value* scatter_value =
      EmitThreadLocalCall(*select_and_scatter->scatter(),
                          {output_value, source_value}, "scatter_function");
  output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_);

  SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_);
  return Status::OK();
}

Status IrEmitter::HandleDot(HloInstruction* dot) {
  auto lhs = dot->operand(0);
  auto rhs = dot->operand(1);
  TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
      /*instruction=*/*dot, /*operands=*/{lhs, rhs},
      /*supported_types=*/{F16, F32, F64, C64, C128}));
  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();

  if (dnums.lhs_contracting_dimensions_size() != 1) {
    // This is disallowed by ShapeInference today.
    return Unimplemented(
        "Dot with multiple contracting dimensions not implemented.");
  }

  llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs));
  llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs));

  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dot));
  llvm_ir::IrArray target_array = GetIrArrayFor(dot);

  VLOG(2) << "HandleDot: ";
  VLOG(2) << "  lhs operand: "
          << llvm_ir::DumpToString(*lhs_array.GetBasePointer());
  VLOG(2) << "  rhs operand: "
          << llvm_ir::DumpToString(*rhs_array.GetBasePointer());
  VLOG(2) << "  target: "
          << llvm_ir::DumpToString(*target_array.GetBasePointer());

  // Dot operation is complicated so we delegate to a helper class.
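  // The helper decides how the dot is lowered, e.g. as a plain IR loop nest
  // or as a call into the Eigen matmul runtime; see dot_op_emitter.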
  return EmitDotOperation(*dot, target_array, lhs_array, rhs_array,
                          /*addend_array=*/nullptr,
                          GetExecutableRunOptionsArgument(), &b_,
                          hlo_module_config_, target_machine_features_);
}

StatusOr<llvm::Value*> IrEmitter::EmitElementalConvolution(
    const HloConvolutionInstruction* convolution,
    const llvm_ir::ElementGenerator& input_generator,
    const llvm_ir::ElementGenerator& kernel_generator,
    const llvm_ir::IrArray::Index& index) {
  const HloInstruction* lhs = convolution->operand(0);
  const HloInstruction* rhs = convolution->operand(1);
  const Window& window = convolution->window();

  const ConvolutionDimensionNumbers& dnums =
      convolution->convolution_dimension_numbers();
  int num_spatial_dims = dnums.output_spatial_dimensions_size();
  std::vector<llvm::Value*> output_spatial(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
  }
  llvm::Value* output_feature = index[dnums.output_feature_dimension()];
  llvm::Value* batch = index[dnums.output_batch_dimension()];

  // We will accumulate the products into this sum to calculate the output
  // entry at the given index.
  PrimitiveType lhs_element_type = lhs->shape().element_type();
  llvm::Type* lhs_llvm_type =
      llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
  llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
      lhs_llvm_type, "convolution_sum_address", &b_,
      MinimumAlignmentForPrimitiveType(lhs_element_type));
  llvm::Value* constant_zero = llvm::Constant::getNullValue(lhs_llvm_type);
  Store(constant_zero, sum_address);

  llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_);
  std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    kernel_spatial[i] =
        loops
            .AddLoop(
                0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
                absl::StrCat("k", i))
            ->GetIndVarValue();
  }
  llvm::Value* input_feature =
      loops
          .AddLoop(0, lhs->shape().dimensions(dnums.input_feature_dimension()),
                   "iz")
          ->GetIndVarValue();

  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);

  // Calculate the spatial index in the input array, taking striding, dilation
  // and padding into account. An index in the padding will be out of the
  // bounds of the array.
  const auto calculate_input_index = [this](llvm::Value* output_index,
                                            llvm::Value* kernel_index,
                                            const WindowDimension& window_dim) {
    llvm::Value* strided_index =
        NSWMul(output_index, b_.getInt64(window_dim.stride()));
    llvm::Value* dilated_kernel_index =
        NSWMul(kernel_index, b_.getInt64(window_dim.window_dilation()));
    return NSWSub(NSWAdd(strided_index, dilated_kernel_index),
                  b_.getInt64(window_dim.padding_low()));
  };
  std::vector<llvm::Value*> input_spatial(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    input_spatial[i] = calculate_input_index(
        output_spatial[i], kernel_spatial[i], window.dimensions(i));
  }

  // We need to check if 0 <= input dim < bound, as otherwise we are in the
  // padding and can skip the computation. That is equivalent to input
  // dim < bound as an *unsigned* comparison, since a negative value will wrap
  // to a large positive value. The input dim is dilated, so we need to dilate
  // the bound as well to match.
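  //
  // For illustration: with base dilation 2 and an input bound of 3, the
  // dilated bound is 5; valid (undilated) coordinates map to dilated
  // positions 0, 2, and 4, while positions 1 and 3 are holes.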

  // Also need to check that the input coordinates are not in one of the
  // holes created by base dilation.
  const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) {
    llvm::Value* remainder = SRem(input_index, b_.getInt64(base_dilation));
    return ICmpEQ(remainder, b_.getInt64(0));
  };

  llvm::Value* in_bounds_condition = b_.getInt1(true);
  for (int i = 0; i < num_spatial_dims; ++i) {
    llvm::ConstantInt* input_bound = b_.getInt64(window_util::DilatedBound(
        lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
        window.dimensions(i).base_dilation()));
    llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound);
    llvm::Value* dim_not_in_hole =
        not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
    llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole);
    in_bounds_condition = And(in_bounds_condition, dim_ok);
  }

  // Now we need to map the dilated base coordinates back to the actual
  // data indices on the lhs.
  const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) {
    return SDiv(input_index, b_.getInt64(base_dilation));
  };
  for (int i = 0; i < num_spatial_dims; ++i) {
    input_spatial[i] =
        undilate(input_spatial[i], window.dimensions(i).base_dilation());
  }

  llvm_ir::LlvmIfData if_data =
      llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
  SetToFirstInsertPoint(if_data.true_block, &b_);

  // We are not in the padding, so carry out the computation.
  int num_dims = num_spatial_dims + 2;
  std::vector<llvm::Value*> input_multi_index(num_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    input_multi_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
  }
  input_multi_index[dnums.input_feature_dimension()] = input_feature;
  input_multi_index[dnums.input_batch_dimension()] = batch;

  std::vector<llvm::Value*> kernel_multi_index(num_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    kernel_multi_index[dnums.kernel_spatial_dimensions(i)] =
        window.dimensions(i).window_reversal()
            ? NSWSub(b_.getInt64(window.dimensions(i).size() - 1),
                     kernel_spatial[i])
            : kernel_spatial[i];
  }

  kernel_multi_index[dnums.kernel_input_feature_dimension()] = input_feature;
  kernel_multi_index[dnums.kernel_output_feature_dimension()] = output_feature;

  llvm_ir::IrArray::Index input_index(input_multi_index, lhs->shape(),
                                      b_.getInt64Ty());
  TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
                      input_generator(input_index));
  llvm_ir::IrArray::Index kernel_index(kernel_multi_index, rhs->shape(),
                                       b_.getInt64Ty());
  TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value,
                      kernel_generator(kernel_index));
  llvm::Value* product = FMul(input_value, kernel_value);
  llvm::Value* sum = FAdd(Load(sum_address), product);
  Store(sum, sum_address);

  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
  return Load(sum_address);
}

Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
  auto lhs = convolution->operand(0);
  auto rhs = convolution->operand(1);
  TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
      /*instruction=*/*convolution, /*operands=*/{lhs, rhs},
      /*supported_types=*/{F16, F32, F64, C64, C128}));

  // TODO(tonywy): Add PotentiallyImplementedAsMKLConvolution to support
  // different data layouts.
  if (PotentiallyImplementedAsEigenConvolution(*convolution,
                                               target_machine_features_)) {
    const Shape& lhs_shape = lhs->shape();
    const Shape& rhs_shape = rhs->shape();
    const Shape& convolution_shape = convolution->shape();
    // The input, kernel and output agree with respect to layout.
    if (LayoutUtil::IsMonotonicWithDim0Major(lhs_shape.layout()) &&
        LayoutUtil::IsMonotonicWithDim0Major(rhs_shape.layout()) &&
        LayoutUtil::IsMonotonicWithDim0Major(convolution_shape.layout())) {
      // We lower 1D convolutions into calls to the same Eigen function as 2D
      // convolutions, except that we pretend that the 1D convolution is really
      // a 2D convolution with the missing dimension set to 1. We also adjust
      // the padding and dilation parameters as needed.
      bool one_dim_convolution = lhs_shape.dimensions_size() == 3;
      llvm::Value* lhs_address = GetEmittedValueFor(lhs);
      llvm::Value* rhs_address = GetEmittedValueFor(rhs);
      TF_RETURN_IF_ERROR(EmitTargetAddressForOp(convolution));

      const ConvolutionDimensionNumbers& dnums =
          convolution->convolution_dimension_numbers();

      // Input tensor.
      const Shape& input_shape = convolution->operand(0)->shape();
      int64 input_batch = input_shape.dimensions(dnums.input_batch_dimension());
      int64 input_rows =
          input_shape.dimensions(dnums.input_spatial_dimensions(0));
      int64 input_cols =
          one_dim_convolution
              ? 1
              : input_shape.dimensions(dnums.input_spatial_dimensions(1));
      int64 input_channels =
          input_shape.dimensions(dnums.input_feature_dimension());

      // Kernel tensor.
      const Shape& kernel_shape = convolution->operand(1)->shape();
      int64 kernel_rows =
          kernel_shape.dimensions(dnums.kernel_spatial_dimensions(0));
      int64 kernel_cols =
          one_dim_convolution
              ? 1
              : kernel_shape.dimensions(dnums.kernel_spatial_dimensions(1));
      int64 kernel_channels =
          kernel_shape.dimensions(dnums.kernel_input_feature_dimension());
      int64 kernel_filters =
          kernel_shape.dimensions(dnums.kernel_output_feature_dimension());

      // Output tensor.
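      // (For a 1D convolution the missing output column dimension is likewise
      // given extent 1 here, mirroring the input and kernel handling above.)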
      const Shape& convolution_shape = convolution->shape();
      int64 output_rows =
          convolution_shape.dimensions(dnums.output_spatial_dimensions(0));
      int64 output_cols = one_dim_convolution
                              ? 1
                              : convolution_shape.dimensions(
                                    dnums.output_spatial_dimensions(1));

      // Extract the window stride for the convolution.
      const Window& window = convolution->window();
      int64 row_stride = window.dimensions(0).stride();
      int64 col_stride =
          one_dim_convolution ? 1 : window.dimensions(1).stride();

      int64 padding_top = window.dimensions(0).padding_low();
      int64 padding_bottom = window.dimensions(0).padding_high();
      int64 padding_left =
          one_dim_convolution ? 0 : window.dimensions(1).padding_low();
      int64 padding_right =
          one_dim_convolution ? 0 : window.dimensions(1).padding_high();

      int64 lhs_row_dilation = window.dimensions(0).base_dilation();
      int64 lhs_col_dilation =
          one_dim_convolution ? 1 : window.dimensions(1).base_dilation();
      int64 rhs_row_dilation = window.dimensions(0).window_dilation();
      int64 rhs_col_dilation =
          one_dim_convolution ? 1 : window.dimensions(1).window_dilation();

      PrimitiveType primitive_type = lhs->shape().element_type();
      llvm::Type* ir_ptr_type = primitive_type == F16
                                    ? b_.getHalfTy()->getPointerTo()
                                    : b_.getFloatTy()->getPointerTo();
      llvm::Type* int64_type = b_.getInt64Ty();
      llvm::Type* int8_ptr_type = b_.getInt8Ty()->getPointerTo();
      llvm::FunctionType* conv_type = llvm::FunctionType::get(
          b_.getVoidTy(),
          {int8_ptr_type, ir_ptr_type, ir_ptr_type, ir_ptr_type, int64_type,
           int64_type, int64_type, int64_type, int64_type, int64_type,
           int64_type, int64_type, int64_type, int64_type, int64_type,
           int64_type, int64_type, int64_type, int64_type, int64_type,
           int64_type, int64_type, int64_type, int64_type},
          /*isVarArg=*/false);
      bool multi_threaded =
          hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
      bool use_mkl_dnn =
          hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();

      // TODO(b/78639006): Single-threaded MKL conv2d is not implemented due to
      // a potential race condition when setting omp_num_threads.
      const char* fn_name =
          primitive_type == F16
              ? (multi_threaded
                     ? runtime::kEigenConvF16SymbolName
                     : runtime::kEigenSingleThreadedConvF16SymbolName)
              : (multi_threaded
                     ? (use_mkl_dnn ? runtime::kMKLConvF32SymbolName
                                    : runtime::kEigenConvF32SymbolName)
                     : runtime::kEigenSingleThreadedConvF32SymbolName);
      if (!multi_threaded && use_mkl_dnn) {
        LOG(WARNING) << "Using Eigen instead of MKL-DNN for single-threaded "
                        "conv2d function.";
      }
      llvm::Function* conv_func = llvm::dyn_cast<llvm::Function>(
          module_->getOrInsertFunction(fn_name, conv_type).getCallee());
      conv_func->setCallingConv(llvm::CallingConv::C);
      conv_func->setDoesNotThrow();
      conv_func->setOnlyAccessesArgMemory();
      Call(conv_func, {
                          GetExecutableRunOptionsArgument(),
                          BitCast(GetEmittedValueFor(convolution), ir_ptr_type),
                          BitCast(lhs_address, ir_ptr_type),
                          BitCast(rhs_address, ir_ptr_type),
                          b_.getInt64(input_batch),
                          b_.getInt64(input_rows),
                          b_.getInt64(input_cols),
                          b_.getInt64(input_channels),
                          b_.getInt64(kernel_rows),
                          b_.getInt64(kernel_cols),
                          b_.getInt64(kernel_channels),
                          b_.getInt64(kernel_filters),
                          b_.getInt64(output_rows),
                          b_.getInt64(output_cols),
                          b_.getInt64(row_stride),
                          b_.getInt64(col_stride),
                          b_.getInt64(padding_top),
                          b_.getInt64(padding_bottom),
                          b_.getInt64(padding_left),
                          b_.getInt64(padding_right),
                          b_.getInt64(lhs_row_dilation),
                          b_.getInt64(lhs_col_dilation),
                          b_.getInt64(rhs_row_dilation),
                          b_.getInt64(rhs_col_dilation),
                      });

      return Status::OK();
    }
  }

  // This is a completely un-optimized version of convolution just to
  // have an early version that works. E.g. the input index and
  // padding calculation is not hoisted out of the inner loop.
  //
  // See the description of convolution in the XLA documentation for the pseudo
  // code for convolution.
  return DefaultAction(convolution);
}

Status IrEmitter::HandleFft(HloInstruction* fft) {
  auto operand = fft->operand(0);
  TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
      /*instruction=*/*fft, /*operands=*/{operand},
      /*supported_types=*/{F32, C64}));
  TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout()));
  TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout()));
  VLOG(3) << "operand=" << ShapeUtil::HumanStringWithLayout(operand->shape());
  VLOG(3) << "fft=" << ShapeUtil::HumanStringWithLayout(fft->shape());

  llvm::Value* operand_address = GetEmittedValueFor(operand);
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fft));

  const std::vector<int64>& fft_length = fft->fft_length();
  int64 input_batch = 1;
  for (int i = 0; i < fft->shape().dimensions_size() - fft_length.size(); i++) {
    input_batch *= fft->shape().dimensions(i);
  }

  // Args have been computed, make the call.
  llvm::Type* int8_ptr_type = b_.getInt8Ty()->getPointerTo();
  llvm::Type* int32_type = b_.getInt32Ty();
  llvm::Type* int64_type = b_.getInt64Ty();
  llvm::FunctionType* fft_type = llvm::FunctionType::get(
      b_.getVoidTy(),
      {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type,
       int64_type, int64_type, int64_type, int64_type},
      /*isVarArg=*/false);

  bool multi_threaded_eigen =
      hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
  const char* fn_name = multi_threaded_eigen
                            ? runtime::kEigenFftSymbolName
                            : runtime::kEigenSingleThreadedFftSymbolName;

  llvm::Function* fft_func = llvm::dyn_cast<llvm::Function>(
      module_->getOrInsertFunction(fn_name, fft_type).getCallee());
  fft_func->setCallingConv(llvm::CallingConv::C);
  fft_func->setDoesNotThrow();
  fft_func->setOnlyAccessesInaccessibleMemOrArgMem();
  const int fft_rank = fft_length.size();
  Call(fft_func,
       {GetExecutableRunOptionsArgument(),
        BitCast(GetEmittedValueFor(fft), int8_ptr_type),
        BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()),
        b_.getInt32(fft_rank), b_.getInt64(input_batch),
        b_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
        b_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
        b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)});

  return Status::OK();
}

Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
  if (hlo_module_config_.replica_count() != 1) {
    // TODO(b/33011107): Support nontrivial cross replica sum on CPU.
    return Unimplemented(
        "AllReduce with >1 replica is not implemented on CPU.");
  }

  // When there is a single replica, a cross replica sum is the identity
  // function, and the buffer assignment expects a copy.
  //
  // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely
  // in algebraic-simplifier, but currently on some platforms
  // HloModuleConfig::num_replicas changes between when the module is compiled
  // and when it's run.
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs));

  // CRS with one operand and one replica is simply the identity function.
  if (crs->operand_count() == 1) {
    return EmitMemcpy(*crs->operand(0), *crs);
  }

  // CRS with multiple operands and one replica produces a (one-deep) tuple.
  std::vector<llvm::Value*> operand_ptrs;
  for (int64 i = 0; i < crs->operand_count(); ++i) {
    llvm::Value* in_ptr = GetEmittedValueFor(crs->operand(i));
    TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
                        assignment_.GetUniqueSlice(crs, {i}));

    const Shape& operand_shape = crs->operand(i)->shape();
    CHECK(operand_shape.IsArray())
        << "Operands to all-reduce must be arrays: " << crs->ToString();
    operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));

    // TODO(b/63762267): Be more aggressive about specifying alignment.
    MemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr,
           /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape));
  }
  llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_);
  return Status::OK();
}

Status IrEmitter::HandleParameter(HloInstruction* parameter) {
  VLOG(2) << "HandleParameter: " << parameter->ToString();
  return EmitTargetAddressForOp(parameter);
}

// Returns true if the relative order of the unreduced dimensions stays the same
// through the reduce operation.
static bool ReductionPreservesLayout(const HloInstruction& reduce) {
  DCHECK_EQ(reduce.opcode(), HloOpcode::kReduce);

  // Maps dimensions that were not reduced from their dimension numbers in the
  // source shape to their dimension numbers in the destination shape.
  //
  // So if we reduce f32[A,B,C,D] on dimensions 1 and 2, this map contains
  // [0->0, 3->1].
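  // (In that example, delta below reaches 2 by the time dimension 3 is
  // visited, so dimension 3 maps to 3 - 2 = 1.)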
  absl::flat_hash_map<int64, int64> unreduced_dim_map;

  absl::flat_hash_set<int64> reduced_dims(reduce.dimensions().begin(),
                                          reduce.dimensions().end());

  const Shape& operand_shape = reduce.operand(0)->shape();
  const Shape& result_shape = reduce.shape();

  int64 delta = 0;
  for (int64 i = 0; i < operand_shape.dimensions_size(); i++) {
    if (reduced_dims.contains(i)) {
      delta++;
    } else {
      InsertOrDie(&unreduced_dim_map, i, i - delta);
    }
  }

  // Iterate dimensions minor to major and check that the corresponding
  // dimensions in the source and target shapes are equivalent.
  int64 result_dim_idx = 0;
  for (int64 operand_dim_idx = 0;
       operand_dim_idx < operand_shape.dimensions_size(); operand_dim_idx++) {
    int64 operand_dim = operand_shape.layout().minor_to_major(operand_dim_idx);
    if (!reduced_dims.contains(operand_dim)) {
      if (FindOrDie(unreduced_dim_map, operand_dim) !=
          result_shape.layout().minor_to_major(result_dim_idx++)) {
        return false;
      }
    }
  }

  CHECK_EQ(result_dim_idx, result_shape.dimensions_size());

  return true;
}

IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator(
    HloComputation* function, string* failure_reason) const {
  CHECK_EQ(function->num_parameters(), 2);

  auto root_instruction = function->root_instruction();
  CHECK(ShapeUtil::IsScalar(root_instruction->shape()));

  if (root_instruction->operand_count() != 2) {
    *failure_reason = "root instruction is not a binary operation";
    return nullptr;
  }

  const Shape& root_shape = root_instruction->shape();
  if (ShapeUtil::ElementIsComplex(root_shape)) {
    // TODO(b/65408531): Complex add could be done via bitcast to <float x [2N]>
    // Complex multiply would be more challenging. We could perhaps use a
    // strided load to get all reals in a vector, all imaginary parts in a
    // vector, or use CreateShuffleVector on a bitcast to float x [2N].
    *failure_reason = "complex values not supported";
    return nullptr;
  }
  bool root_is_floating_point = ShapeUtil::ElementIsFloating(root_shape);
  bool root_is_integral = ShapeUtil::ElementIsIntegral(root_shape);
  bool root_is_signed = ShapeUtil::ElementIsSigned(root_shape);

  auto lhs = root_instruction->operand(0);
  auto rhs = root_instruction->operand(1);

  auto param_0 = function->parameter_instruction(0);
  auto param_1 = function->parameter_instruction(1);
  if (!(lhs == param_0 && rhs == param_1) &&
      !(rhs == param_0 && lhs == param_1)) {
    *failure_reason =
        "root instruction is not a binary operation on the incoming arguments";
    return nullptr;
  }

  CHECK(ShapeUtil::IsScalar(lhs->shape()) && ShapeUtil::IsScalar(rhs->shape()));

  // This is visually similar to ElementalIrEmitter, though conceptually we're
  // doing something different here. ElementalIrEmitter emits scalar operations
  // while these emit scalar or vector operations depending on the type of the
  // operands. See CreateShardedVectorType for the actual types in use here.
  switch (root_instruction->opcode()) {
    default:
      *failure_reason = "did not recognize root instruction opcode";
      return nullptr;

    case HloOpcode::kAdd:
      return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs,
                                llvm::Value* rhs) {
        return root_is_integral ?
b->CreateAdd(lhs, rhs) 1453 : b->CreateFAdd(lhs, rhs); 1454 }; 1455 1456 case HloOpcode::kMultiply: 1457 return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs, 1458 llvm::Value* rhs) { 1459 return root_is_integral ? b->CreateMul(lhs, rhs) 1460 : b->CreateFMul(lhs, rhs); 1461 }; 1462 1463 case HloOpcode::kAnd: 1464 return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { 1465 return b->CreateAnd(lhs, rhs); 1466 }; 1467 1468 case HloOpcode::kOr: 1469 return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { 1470 return b->CreateOr(lhs, rhs); 1471 }; 1472 1473 case HloOpcode::kXor: 1474 return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { 1475 return b->CreateXor(lhs, rhs); 1476 }; 1477 1478 case HloOpcode::kMaximum: 1479 return [root_is_floating_point, root_is_signed]( 1480 llvm::IRBuilder<>* b, llvm::Value* lhs, 1481 llvm::Value* rhs) -> llvm::Value* { 1482 if (root_is_floating_point) { 1483 return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::maxnum, 1484 {lhs, rhs}, {lhs->getType()}, b); 1485 } 1486 1487 return b->CreateSelect( 1488 b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SGE 1489 : llvm::ICmpInst::ICMP_UGE, 1490 lhs, rhs), 1491 lhs, rhs); 1492 }; 1493 1494 case HloOpcode::kMinimum: 1495 return [root_is_floating_point, root_is_signed]( 1496 llvm::IRBuilder<>* b, llvm::Value* lhs, 1497 llvm::Value* rhs) -> llvm::Value* { 1498 if (root_is_floating_point) { 1499 return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::minnum, 1500 {lhs, rhs}, {lhs->getType()}, b); 1501 } 1502 1503 return b->CreateSelect( 1504 b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SLE 1505 : llvm::ICmpInst::ICMP_ULE, 1506 lhs, rhs), 1507 lhs, rhs); 1508 }; 1509 } 1510 } 1511 1512 IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType( 1513 PrimitiveType element_type, unsigned element_count) { 1514 int vector_register_size_in_elements = 1515 target_machine_features_.vector_register_byte_size( 1516 *compute_function_->function()) / 1517 ShapeUtil::ByteSizeOfPrimitiveType(element_type); 1518 1519 ShardedVectorType sharded_vector_type; 1520 llvm::Type* element_ir_type = 1521 llvm_ir::PrimitiveTypeToIrType(element_type, module_); 1522 1523 for (int i = 0, e = 1 + tensorflow::Log2Ceiling(element_count); i < e; i++) { 1524 // For every power of two present in element_count, we generate one or more 1525 // vector or scalar types. 1526 const unsigned current_size_fragment = 1u << i; 1527 if (!(element_count & current_size_fragment)) { 1528 // Power of two not present in element_count. 1529 continue; 1530 } 1531 1532 if (current_size_fragment == 1) { 1533 // Single element, use a scalar type. 1534 sharded_vector_type.push_back(element_ir_type); 1535 continue; 1536 } 1537 1538 // Lower "current_size_fragment" number of elements using (as few as 1539 // possible) vector registers. 1540 1541 if (current_size_fragment >= vector_register_size_in_elements) { 1542 auto vector_type = llvm::VectorType::get( 1543 element_ir_type, vector_register_size_in_elements); 1544 sharded_vector_type.insert( 1545 sharded_vector_type.end(), 1546 current_size_fragment / vector_register_size_in_elements, 1547 vector_type); 1548 1549 // Both current_size_fragment and vector_register_size_in_elements are 1550 // powers of two. 
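      // E.g. (assuming 8 lanes per vector register) a 16-element fragment
      // becomes exactly two full 8-wide vectors, which is what the CHECK below
      // relies on.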
1551 CHECK_EQ(current_size_fragment % vector_register_size_in_elements, 0); 1552 continue; 1553 } 1554 1555 // For now we assume that vector_register_size_in_elements and lower powers 1556 // of two are all legal vector sizes (or at least can be lowered easily by 1557 // LLVM). 1558 sharded_vector_type.push_back( 1559 llvm::VectorType::get(element_ir_type, current_size_fragment)); 1560 } 1561 return sharded_vector_type; 1562 } 1563 1564 StatusOr<IrEmitter::ShardedVector> 1565 IrEmitter::EmitInnerLoopForVectorizedReduction( 1566 const ReductionGenerator& reduction_generator, 1567 const llvm_ir::IrArray::Index& output_index, 1568 const ShardedVectorType& accumulator_type, HloInstruction* init_value, 1569 HloInstruction* arg, absl::Span<const int64> dimensions, 1570 unsigned element_alignment) { 1571 ShardedVector accumulator; 1572 accumulator.reserve(accumulator_type.size()); 1573 for (auto accumulator_shard_type : accumulator_type) { 1574 accumulator.push_back(llvm_ir::EmitAllocaAtFunctionEntry( 1575 accumulator_shard_type, "accumulator", &b_, 0)); 1576 } 1577 1578 llvm::Value* init_value_ssa = Load(GetEmittedValueFor(init_value)); 1579 1580 for (llvm::Value* accumulator_shard : accumulator) { 1581 llvm::Value* initial_value; 1582 auto shard_type = accumulator_shard->getType()->getPointerElementType(); 1583 if (auto vector_type = llvm::dyn_cast<llvm::VectorType>(shard_type)) { 1584 initial_value = 1585 VectorSplat(vector_type->getNumElements(), init_value_ssa); 1586 } else { 1587 initial_value = init_value_ssa; 1588 } 1589 1590 AlignedStore(initial_value, accumulator_shard, element_alignment); 1591 } 1592 1593 llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"), 1594 &b_); 1595 std::vector<llvm::Value*> input_multi_index = 1596 reduction_loop_nest.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, 1597 "reduction_dim"); 1598 1599 SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(), &b_); 1600 1601 llvm_ir::IrArray arg_array(GetIrArrayFor(arg)); 1602 llvm_ir::IrArray::Index::const_iterator it = output_index.begin(); 1603 1604 for (auto& i : input_multi_index) { 1605 if (i == nullptr) { 1606 i = *it++; 1607 } 1608 } 1609 CHECK(output_index.end() == it); 1610 llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(), 1611 b_.getInt64Ty()); 1612 1613 llvm::Value* input_address = BitCast( 1614 arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy()); 1615 1616 for (int i = 0; i < accumulator.size(); i++) { 1617 auto input_address_typed = 1618 BitCast(input_address, accumulator[i]->getType()); 1619 auto current_accumulator_value = 1620 AlignedLoad(accumulator[i], element_alignment); 1621 auto addend = AlignedLoad(input_address_typed, element_alignment); 1622 arg_array.AnnotateLoadStoreInstructionWithMetadata(addend); 1623 1624 auto reduced_result = 1625 reduction_generator(&b_, current_accumulator_value, addend); 1626 AlignedStore(reduced_result, accumulator[i], element_alignment); 1627 1628 if (i != (accumulator.size() - 1)) { 1629 input_address = ConstInBoundsGEP1_32(reduced_result->getType(), 1630 input_address_typed, 1); 1631 } 1632 } 1633 1634 SetToFirstInsertPoint(reduction_loop_nest.GetOuterLoopExitBasicBlock(), &b_); 1635 1636 ShardedVector result_ssa; 1637 result_ssa.reserve(accumulator.size()); 1638 for (auto accumulator_shard : accumulator) { 1639 result_ssa.push_back(AlignedLoad(accumulator_shard, element_alignment)); 1640 } 1641 return result_ssa; 1642 } 1643 1644 void IrEmitter::EmitShardedVectorStore( 1645 
llvm::Value* store_address, const std::vector<llvm::Value*>& value_to_store, 1646 const int alignment, const llvm_ir::IrArray& containing_array) { 1647 for (int i = 0; i < value_to_store.size(); i++) { 1648 auto store_address_typed = 1649 BitCast(store_address, 1650 llvm::PointerType::getUnqual(value_to_store[i]->getType())); 1651 1652 auto store_instruction = 1653 AlignedStore(value_to_store[i], store_address_typed, alignment); 1654 containing_array.AnnotateLoadStoreInstructionWithMetadata( 1655 store_instruction); 1656 1657 if (i != (value_to_store.size() - 1)) { 1658 store_address = ConstInBoundsGEP1_32(value_to_store[i]->getType(), 1659 store_address_typed, 1); 1660 } 1661 } 1662 } 1663 1664 StatusOr<bool> IrEmitter::EmitVectorizedReduce( 1665 HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, 1666 absl::Span<const int64> dimensions, HloComputation* function, 1667 string* failure_reason) { 1668 if (!ReductionPreservesLayout(*reduce)) { 1669 return false; 1670 } 1671 1672 ReductionGenerator reduction_generator = 1673 MatchReductionGenerator(function, failure_reason); 1674 if (!reduction_generator) { 1675 return false; 1676 } 1677 1678 int vectorization_factor_in_bytes = 1679 target_machine_features_.vectorization_factor_in_bytes(); 1680 1681 // We try to process vectorization_factor elements at the same time. 1682 const int vectorization_factor = 1683 vectorization_factor_in_bytes / 1684 ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()); 1685 1686 bool is_reduction_over_minor_dimension = absl::c_linear_search( 1687 dimensions, LayoutUtil::Minor(arg->shape().layout(), 0)); 1688 1689 unsigned element_alignment = tensorflow::MathUtil::GCD<unsigned>( 1690 ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()), 1691 MinimumAlignmentForPrimitiveType(reduce->shape().element_type())); 1692 1693 if (is_reduction_over_minor_dimension) { 1694 // TODO(sanjoy): Implement vectorized reduction over the minor dimension. 1695 *failure_reason = "reduction over minor dimension not implemented"; 1696 return false; 1697 } 1698 1699 CHECK(!reduce->shape().IsTuple()); 1700 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(reduce)); 1701 1702 // We know we're not reducing over the most minor dimension, which means we 1703 // can lower the reduction loop as: 1704 // 1705 // 1. We're reducing over dimensions R0, R1. 1706 // 2. D0 is the most minor dimension. 1707 // 3. 
VS is the vectorization stride (we want to reduce this many elements at 1708 // once) 1709 // 1710 // for (d1 in D1) { 1711 // for (d0 in D0 with stride VS) { 1712 // vector_acc = init 1713 // for (r1 in R1) { 1714 // for (r0 in R0) { 1715 // vector_acc = elementwise_reduce(vector_acc, input[d1, d0, r1, r0] 1716 // } 1717 // } 1718 // output[d1, d0] = vector_acc 1719 // } 1720 // } 1721 1722 llvm_ir::ForLoopNest loop_nest(IrName(reduce), &b_); 1723 std::vector<llvm::Value*> array_multi_index( 1724 reduce->shape().dimensions_size()); 1725 for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0; 1726 --i) { 1727 int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i); 1728 int64 start_index = 0; 1729 int64 end_index = reduce->shape().dimensions(dimension); 1730 std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop( 1731 start_index, end_index, absl::StrFormat("dim.%d", dimension)); 1732 array_multi_index[dimension] = loop->GetIndVarValue(); 1733 } 1734 1735 int64 innermost_dimension = LayoutUtil::Minor(reduce->shape().layout(), 0); 1736 int64 innermost_dimension_size = 1737 reduce->shape().dimensions(innermost_dimension); 1738 1739 if (llvm::BasicBlock* innermost_body_bb = 1740 loop_nest.GetInnerLoopBodyBasicBlock()) { 1741 SetToFirstInsertPoint(innermost_body_bb, &b_); 1742 } 1743 1744 auto outermost_loop_exit_block = loop_nest.GetOuterLoopExitBasicBlock(); 1745 1746 if (innermost_dimension_size >= vectorization_factor) { 1747 int64 start_index = 0; 1748 int64 end_index = (innermost_dimension_size / vectorization_factor) * 1749 vectorization_factor; 1750 std::unique_ptr<llvm_ir::ForLoop> loop = 1751 loop_nest.AddLoop(start_index, end_index, vectorization_factor, 1752 absl::StrFormat("dim.%d", innermost_dimension)); 1753 array_multi_index[innermost_dimension] = loop->GetIndVarValue(); 1754 1755 SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &b_); 1756 1757 ShardedVectorType vector_type = CreateShardedVectorType( 1758 reduce->shape().element_type(), vectorization_factor); 1759 llvm_ir::IrArray::Index array_index(array_multi_index, reduce->shape(), 1760 b_.getInt64Ty()); 1761 TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accumulator, 1762 EmitInnerLoopForVectorizedReduction( 1763 reduction_generator, array_index, vector_type, 1764 init_value, arg, dimensions, element_alignment)); 1765 1766 llvm_ir::IrArray target_array = GetIrArrayFor(reduce); 1767 llvm::Value* output_address = 1768 target_array.EmitArrayElementAddress(array_index, &b_); 1769 EmitShardedVectorStore(output_address, accumulator, element_alignment, 1770 target_array); 1771 1772 if (auto exit_terminator = loop->GetExitBasicBlock()->getTerminator()) { 1773 CHECK_GT(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1); 1774 b_.SetInsertPoint(exit_terminator); 1775 } else { 1776 CHECK_EQ(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1); 1777 b_.SetInsertPoint(loop->GetExitBasicBlock()); 1778 } 1779 } 1780 1781 // Since we increment the stride for the inner dimension by more than 1, we 1782 // may need to peel out an "epilogue" iteration to get the remaining elements 1783 // in the following case: 1784 if (innermost_dimension_size % vectorization_factor) { 1785 // TODO(b/63775531): Consider using a scalar loop here to save on code size. 
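    // For example (illustrative sizes): with a vectorization factor of, say, 8
    // for F32, an innermost dimension of size 10 is handled as one 8-wide step
    // by the vectorized loop above and by a 2-element sharded vector here for
    // the remaining elements.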
1786 array_multi_index[innermost_dimension] = 1787 b_.getInt64(innermost_dimension_size - 1788 (innermost_dimension_size % vectorization_factor)); 1789 1790 ShardedVectorType vector_type = CreateShardedVectorType( 1791 reduce->shape().element_type(), 1792 innermost_dimension_size % vectorization_factor); 1793 llvm_ir::IrArray::Index array_index(array_multi_index, reduce->shape(), 1794 b_.getInt64Ty()); 1795 TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accumulator, 1796 EmitInnerLoopForVectorizedReduction( 1797 reduction_generator, array_index, vector_type, 1798 init_value, arg, dimensions, element_alignment)); 1799 1800 llvm_ir::IrArray target_array = GetIrArrayFor(reduce); 1801 llvm::Value* output_address = 1802 target_array.EmitArrayElementAddress(array_index, &b_); 1803 EmitShardedVectorStore(output_address, accumulator, element_alignment, 1804 target_array); 1805 } 1806 1807 if (outermost_loop_exit_block) { 1808 b_.SetInsertPoint(outermost_loop_exit_block); 1809 } 1810 1811 return true; 1812 } 1813 1814 StatusOr<llvm::Value*> IrEmitter::EmitElementalReduce( 1815 const HloReduceInstruction* reduce, 1816 const llvm_ir::ElementGenerator& input_generator, 1817 const llvm_ir::ElementGenerator& initial_value_generator, 1818 const llvm_ir::IrArray::Index& index) { 1819 const HloInstruction* arg = reduce->operand(0); 1820 absl::Span<const int64> dimensions(reduce->dimensions()); 1821 1822 // Initialize an accumulator with init_value. 1823 PrimitiveType accumulator_type = reduce->shape().element_type(); 1824 llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry( 1825 llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator", 1826 &b_, MinimumAlignmentForPrimitiveType(accumulator_type)); 1827 TF_ASSIGN_OR_RETURN( 1828 llvm::Value* const init_value, 1829 initial_value_generator(llvm_ir::IrArray::Index(index.GetType()))); 1830 Store(init_value, accumulator_addr); 1831 1832 // The enclosing loops go over all the target elements. Now we have to compute 1833 // the actual target element. For this, we build a new loop nest to iterate 1834 // over all the reduction dimensions in the argument. 1835 // AddLoopsForShapeOnDimensions will return an Index where induction Value*s 1836 // are placed for each dimension in dimensions, and all the rest are nullptrs. 1837 llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_); 1838 std::vector<llvm::Value*> input_multi_index = 1839 loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, 1840 "reduction_dim"); 1841 1842 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); 1843 1844 // Build a full index for the input argument, using reduced_dims_index as the 1845 // base. In reduced_dims_index only the reduction dimensions are filled in. We 1846 // fill in the rest of the dimensions with induction Value*s taken from 1847 // 'index' which iterates over the target array. See the high-level 1848 // description in the XLA documentation for details. 1849 llvm_ir::IrArray::Index::const_iterator it = index.begin(); 1850 1851 for (auto& i : input_multi_index) { 1852 if (i == nullptr) { 1853 i = *it++; 1854 } 1855 } 1856 CHECK(index.end() == it); 1857 llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(), 1858 b_.getInt64Ty()); 1859 1860 // Apply the reduction function to the loaded value. 
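  // For a scalar-add reduction this amounts to
  //   accumulator = add(accumulator, input[input_index]);
  // arbitrary reduction computations work because the call goes through
  // EmitThreadLocalCall below.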
1861 TF_ASSIGN_OR_RETURN(llvm::Value* const input_element, 1862 input_generator(input_index)); 1863 llvm::Value* result = EmitThreadLocalCall( 1864 *reduce->to_apply(), {Load(accumulator_addr), input_element}, 1865 "reduce_function"); 1866 Store(result, accumulator_addr); 1867 1868 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); 1869 return Load(accumulator_addr); 1870 } 1871 1872 Status IrEmitter::HandleReduce(HloInstruction* reduce) { 1873 // TODO(b/118333695): Support variadic reduce. 1874 if (!reduce->shape().IsArray()) { 1875 return Unimplemented("Variadic reduce is not supported on CPU"); 1876 } 1877 auto arg = reduce->mutable_operand(0); 1878 auto init_value = reduce->mutable_operand(1); 1879 absl::Span<const int64> dimensions(reduce->dimensions()); 1880 HloComputation* function = reduce->to_apply(); 1881 if (!options::VectorizedReduceDisabled(hlo_module_config_)) { 1882 string vectorization_failure_reason; 1883 TF_ASSIGN_OR_RETURN( 1884 bool vectorization_successful, 1885 EmitVectorizedReduce(reduce, arg, init_value, dimensions, function, 1886 &vectorization_failure_reason)); 1887 if (vectorization_successful) { 1888 VLOG(1) << "Successfully vectorized reduction " << reduce->ToString() 1889 << "\n"; 1890 return Status::OK(); 1891 } else { 1892 VLOG(1) << "Could not vectorize reduction " << reduce->ToString() << ": " 1893 << vectorization_failure_reason; 1894 } 1895 } 1896 1897 return DefaultAction(reduce); 1898 } 1899 1900 Status IrEmitter::HandleAllToAll(HloInstruction*) { 1901 return Unimplemented("AllToAll is not implemented on CPU."); 1902 } 1903 1904 Status IrEmitter::HandleSend(HloInstruction* send) { 1905 // TODO(b/33942983): Support Send/Recv on CPU. 1906 return Unimplemented("Send is not implemented on CPU."); 1907 } 1908 1909 Status IrEmitter::HandleSendDone(HloInstruction* send_done) { 1910 // TODO(b/33942983): Support Send/Recv on CPU. 1911 return Unimplemented("Send-done is not implemented on CPU."); 1912 } 1913 1914 Status IrEmitter::HandleScatter(HloInstruction*) { 1915 return Unimplemented("Scatter is not implemented on CPUs."); 1916 } 1917 1918 Status IrEmitter::HandleSlice(HloInstruction* slice) { 1919 VLOG(2) << "HandleSlice: " << slice->ToString(); 1920 auto operand = slice->operand(0); 1921 // The code below emits a sequential loop nest. For the parallel backend, use 1922 // ParallelLoopEmitter which respects dynamic loop bounds. 1923 if (ShouldEmitParallelLoopFor(*slice)) { 1924 return DefaultAction(slice); 1925 } 1926 1927 // The code below assumes the layouts are equal. 1928 if (!LayoutUtil::Equal(operand->shape().layout(), slice->shape().layout())) { 1929 return DefaultAction(slice); 1930 } 1931 1932 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice)); 1933 1934 if (ShapeUtil::IsZeroElementArray(slice->shape())) { 1935 return Status::OK(); 1936 } 1937 1938 const Layout& layout = operand->shape().layout(); 1939 const int64 num_dims = operand->shape().dimensions_size(); 1940 1941 // The slice lowering finds maximal contiguous blocks of memory that can be 1942 // copied from the source to the target. This is done by looking at the 1943 // source/target layout in minor to major order and do the following: 1944 // 1945 // * Find an initial segment of dimensions along which the slice uses the 1946 // whole dimension. These are the "inner" dimensions and can be folded into 1947 // the memcpy. 1948 // 1949 // * Of the remaining dimensions decide which ones require loops. 1950 // 1951 // * Implement the memcpy within the innermost loop. 
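  //
  // For example (illustrative shapes): slicing f32[8,16,32]{2,1,0} with starts
  // {0,4,0}, limits {8,8,32} and unit strides keeps all of dimension 2, so
  // dimension 2 folds into the memcpy; dimension 1 is the memcpy dimension and
  // is contiguous, so each memcpy moves 4 * 32 elements; dimension 0 becomes
  // the single outer loop.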
1952 1953 absl::flat_hash_set<int64> inner_dims; 1954 for (int64 dim : LayoutUtil::MinorToMajor(layout)) { 1955 if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) { 1956 break; 1957 } 1958 inner_dims.insert(dim); 1959 } 1960 1961 const bool is_trivial_copy = (inner_dims.size() == num_dims); 1962 if (is_trivial_copy) { 1963 if (ShapeUtil::IsEffectiveScalar(slice->shape())) { 1964 return DefaultAction(slice); 1965 } else { 1966 return EmitMemcpy(*slice, *operand); 1967 } 1968 } 1969 1970 // The memcpy will copy elements that are logically this shape (allowed to be 1971 // scalar). 1972 const Shape logical_element_shape = ShapeUtil::FilterDimensions( 1973 [&inner_dims](int64 dim) { return inner_dims.contains(dim); }, 1974 operand->shape()); 1975 1976 const int64 primitive_elements_per_logical_element = 1977 ShapeUtil::ElementsIn(logical_element_shape); 1978 1979 // memcpy_dim is the innermost (in terms of layout) dimension for which the 1980 // slice does *not* just copy all the elements along the dimension. 1981 const int64 memcpy_dim = LayoutUtil::Minor(layout, inner_dims.size()); 1982 1983 const bool memcpy_is_contiguous = slice->slice_strides(memcpy_dim) == 1; 1984 // The number of logical elements that can be copied in a single call 1985 // to memcpy. We can only copy 1 element at a time if there is a non-trivial 1986 // stride. 1987 const int64 memcpy_logical_elements = 1988 memcpy_is_contiguous 1989 ? slice->slice_limits(memcpy_dim) - slice->slice_starts(memcpy_dim) 1990 : 1; 1991 1992 // Determine the dimensions that get lowered as loops. 1993 std::vector<int64> outer_dims; 1994 for (int64 i = 0; i < num_dims - inner_dims.size() - 1; ++i) { 1995 outer_dims.push_back(LayoutUtil::Major(layout, i)); 1996 } 1997 1998 // Is the slice along the memcpy dimension contiguous? If not, then memcpy_dim 1999 // needs to be wrapped around a loop as well. 2000 if (!memcpy_is_contiguous) { 2001 outer_dims.push_back(memcpy_dim); 2002 } 2003 2004 llvm_ir::IrArray target_array = GetIrArrayFor(slice); 2005 2006 const int64 num_outer_loops = outer_dims.size(); 2007 llvm_ir::ForLoopNest loops(IrName(slice), &b_); 2008 std::vector<llvm::Value*> target_multi_index = 2009 loops.AddLoopsForShapeOnDimensions(slice->shape(), outer_dims, "slice"); 2010 2011 // Only the indices for the outer dimensions have been initialized in 2012 // target_index. The rest of the indices should get initialized to 0, since 2013 // for the rest of the dimensions the copy writes to the full dimension. 
2014 std::replace(target_multi_index.begin(), target_multi_index.end(), 2015 static_cast<llvm::Value*>(nullptr), 2016 static_cast<llvm::Value*>(b_.getInt64(0))); 2017 llvm_ir::IrArray::Index target_index(target_multi_index, slice->shape(), 2018 b_.getInt64Ty()); 2019 2020 if (num_outer_loops > 0) { 2021 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); 2022 } 2023 2024 llvm_ir::IrArray source_array = GetIrArrayFor(operand); 2025 const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice( 2026 /*operand_shape=*/operand->shape(), /*starts=*/slice->slice_starts(), 2027 /*strides=*/slice->slice_strides(), /*builder=*/&b_); 2028 2029 llvm::Value* memcpy_dest = 2030 target_array.EmitArrayElementAddress(target_index, &b_, "slice.dest"); 2031 llvm::Value* memcpy_source = 2032 source_array.EmitArrayElementAddress(source_index, &b_, "slice.source"); 2033 2034 const int64 memcpy_elements = 2035 primitive_elements_per_logical_element * memcpy_logical_elements; 2036 2037 EmitTransferElements(memcpy_dest, memcpy_source, memcpy_elements, 2038 slice->shape().element_type(), target_array, 2039 source_array); 2040 2041 if (VLOG_IS_ON(2)) { 2042 const int64 memcpy_bytes = 2043 ShapeUtil::ByteSizeOf(logical_element_shape) * memcpy_elements; 2044 VLOG(2) << " emitted copy of " << memcpy_bytes << " bytes inside " 2045 << num_outer_loops << " loops"; 2046 } 2047 2048 if (num_outer_loops > 0) { 2049 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); 2050 } 2051 2052 return Status::OK(); 2053 } 2054 2055 Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice) { 2056 if (ShapeUtil::IsScalar(dynamic_slice->shape())) { 2057 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_slice)); 2058 return EmitMemcpy(*dynamic_slice->operand(0), *dynamic_slice); 2059 } 2060 return DefaultAction(dynamic_slice); 2061 } 2062 2063 Status IrEmitter::HandleDynamicUpdateSlice( 2064 HloInstruction* dynamic_update_slice) { 2065 auto update = dynamic_update_slice->operand(1); 2066 if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) { 2067 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice)); 2068 return EmitMemcpy(*update, *dynamic_update_slice); 2069 } else if (llvm_ir::CanUpdateDynamicSliceInPlace(dynamic_update_slice, 2070 assignment_)) { 2071 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice)); 2072 auto operands = GetIrArraysForOperandsOf(dynamic_update_slice); 2073 return llvm_ir::EmitDynamicUpdateSliceInPlace( 2074 operands, GetIrArrayFor(dynamic_update_slice), 2075 IrName(dynamic_update_slice, "in_place"), &b_); 2076 } 2077 return DefaultAction(dynamic_update_slice); 2078 } 2079 2080 Status IrEmitter::HandleRecv(HloInstruction* recv) { 2081 // TODO(b/33942983): Support Send/Recv on CPU. 2082 return Unimplemented("Recv is not implemented on CPU."); 2083 } 2084 2085 Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) { 2086 // TODO(b/33942983): Support Send/Recv on CPU. 2087 return Unimplemented("Recv-done is not implemented on CPU."); 2088 } 2089 2090 Status IrEmitter::HandlePad(HloInstruction* pad) { 2091 // CPU backend does not properly handle negative padding but this is ok 2092 // because negative padding should be removed by the algebraic simplifier. 2093 for (auto& padding_dimension : pad->padding_config().dimensions()) { 2094 if (padding_dimension.edge_padding_low() < 0 || 2095 padding_dimension.edge_padding_high() < 0) { 2096 return InternalErrorStrCat( 2097 "Encountered negative padding in IrEmitter on CPU. 
" 2098 "This should have been eliminated at the HLO level. ", 2099 pad->ToString()); 2100 } 2101 } 2102 2103 // First, fill in the padding value to all output elements. 2104 TF_RETURN_IF_ERROR(EmitTargetElementLoop( 2105 pad, "initialize", 2106 [this, pad](const llvm_ir::IrArray::Index& target_index) { 2107 const HloInstruction* padding_value = pad->operand(1); 2108 llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value); 2109 return Load(padding_value_addr); 2110 })); 2111 2112 // Create a loop to iterate over the operand elements and update the output 2113 // locations where the operand elements should be stored. 2114 llvm_ir::ForLoopNest loops(IrName(pad, "assign"), &b_); 2115 const HloInstruction* operand = pad->operand(0); 2116 const llvm_ir::IrArray::Index operand_index = 2117 loops.AddLoopsForShape(operand->shape(), "operand"); 2118 2119 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); 2120 2121 // Load an element from the operand. 2122 llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); 2123 llvm::Value* operand_data = 2124 operand_array.EmitReadArrayElement(operand_index, &b_); 2125 2126 // Compute the output index the operand element should be assigned to. 2127 // output_index := edge_padding_low + operand_index * (interior_padding + 1) 2128 const PaddingConfig& padding_config = pad->padding_config(); 2129 std::vector<llvm::Value*> output_multi_index; 2130 for (size_t i = 0; i < operand_index.size(); ++i) { 2131 llvm::Value* offset = 2132 Mul(operand_index[i], 2133 b_.getInt64(padding_config.dimensions(i).interior_padding() + 1)); 2134 llvm::Value* index = Add( 2135 offset, b_.getInt64(padding_config.dimensions(i).edge_padding_low())); 2136 output_multi_index.push_back(index); 2137 } 2138 2139 // Store the operand element to the computed output location. 2140 llvm_ir::IrArray output_array(GetIrArrayFor(pad)); 2141 llvm_ir::IrArray::Index output_index( 2142 output_multi_index, output_array.GetShape(), operand_index.GetType()); 2143 output_array.EmitWriteArrayElement(output_index, operand_data, &b_); 2144 2145 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); 2146 return Status::OK(); 2147 } 2148 2149 Status IrEmitter::HandleFusion(HloInstruction* fusion) { 2150 auto* root = fusion->fused_expression_root(); 2151 if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) { 2152 VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace"; 2153 CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); 2154 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); 2155 // Delegate to common implementation of fused in-place dynamic-update-slice. 2156 return llvm_ir::EmitFusedDynamicUpdateSliceInPlace( 2157 fusion, GetGeneratorForOperandIrArrays(fusion), GetIrArrayFor(fusion), 2158 &elemental_emitter, &b_); 2159 } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) { 2160 VLOG(3) << "HandleFusion kLoop"; 2161 CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); 2162 auto operands = GetIrArraysForOperandsOf(fusion); 2163 FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion), 2164 &elemental_emitter); 2165 TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); 2166 2167 return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator()); 2168 } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kOutput) { 2169 VLOG(3) << "HandleFusion kOutput"; 2170 int64 dot_op_index = root->operand(0)->opcode() == HloOpcode::kDot ? 
0 : 1; 2171 const HloInstruction* dot = root->operand(dot_op_index); 2172 CHECK_EQ(dot->opcode(), HloOpcode::kDot) 2173 << dot->ToString() << " " 2174 << fusion->fused_instructions_computation()->ToString(); 2175 2176 int64 dot_lhs_param_number = dot->operand(0)->parameter_number(); 2177 int64 dot_rhs_param_number = dot->operand(1)->parameter_number(); 2178 int64 addend_param_number = 2179 root->operand(1 - dot_op_index)->parameter_number(); 2180 2181 Shape target_shape = fusion->shape(); 2182 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); 2183 llvm_ir::IrArray target_array = GetIrArrayFor(fusion); 2184 2185 llvm_ir::IrArray lhs_array( 2186 GetIrArrayFor(fusion->operand(dot_lhs_param_number))); 2187 llvm_ir::IrArray rhs_array( 2188 GetIrArrayFor(fusion->operand(dot_rhs_param_number))); 2189 llvm_ir::IrArray addend_array( 2190 GetIrArrayFor(fusion->operand(addend_param_number))); 2191 2192 TF_RETURN_IF_ERROR( 2193 EmitDotOperation(*dot, target_array, lhs_array, rhs_array, 2194 &addend_array, GetExecutableRunOptionsArgument(), &b_, 2195 hlo_module_config_, target_machine_features_)); 2196 return Status::OK(); 2197 } else { 2198 return Unimplemented("Fusion kind not implemented on CPU"); 2199 } 2200 } 2201 2202 Status IrEmitter::HandleCall(HloInstruction* call) { 2203 HloComputation* computation = call->to_apply(); 2204 llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation); 2205 2206 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call)); 2207 2208 if (!computation->root_instruction()->outer_dimension_partitions().empty()) { 2209 // ParallelTaskAssignment assigned partitions, emit call to 2210 // ParallelForkJoin. 2211 std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments( 2212 {}, &b_, computation->name(), 2213 /*return_value_buffer=*/emitted_value_[call], 2214 /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), 2215 /*buffer_table_arg=*/GetBufferTableArgument(), 2216 /*profile_counters_arg=*/GetProfileCountersArgument()); 2217 2218 HloInstruction* root = computation->root_instruction(); 2219 TF_RETURN_IF_ERROR(EmitCallToParallelForkJoin( 2220 call_args, root->shape(), root->outer_dimension_partitions(), &b_, 2221 call_ir_function, computation->name())); 2222 } else { 2223 EmitGlobalCall(*computation, computation->name()); 2224 } 2225 2226 return Status::OK(); 2227 } 2228 2229 Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { 2230 absl::Span<HloInstruction* const> operands(custom_call->operands()); 2231 llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); 2232 llvm::AllocaInst* operands_alloca = 2233 llvm_ir::EmitAllocaAtFunctionEntryWithCount( 2234 i8_ptr_type, b_.getInt32(operands.size()), "cc_operands_alloca", &b_); 2235 for (size_t i = 0; i < operands.size(); ++i) { 2236 const HloInstruction* operand = operands[i]; 2237 llvm::Value* operand_as_i8ptr = 2238 PointerCast(GetEmittedValueFor(operand), i8_ptr_type); 2239 llvm::Value* slot_in_operands_alloca = 2240 InBoundsGEP(operands_alloca, {b_.getInt64(i)}); 2241 Store(operand_as_i8ptr, slot_in_operands_alloca); 2242 } 2243 if (emit_code_for_msan_) { 2244 // Mark the alloca as initialized for msan. The buffer gets read by the 2245 // custom callee, which might be msan-instrumented. 2246 // TODO(b/66051036): Run the msan instrumentation pass instead. 
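    // The call below is effectively
    //   __msan_unpoison(operands_alloca, sizeof(void*) * operands.size());
    // i.e. it covers exactly the bytes of the pointer array allocated above.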
2247 const llvm::DataLayout& dl = module_->getDataLayout(); 2248 llvm::Type* intptr_type = b_.getIntPtrTy(dl); 2249 auto* msan_unpoison_ir_function = llvm::cast<llvm::Function>( 2250 module_ 2251 ->getOrInsertFunction( 2252 "__msan_unpoison", 2253 llvm::FunctionType::get( 2254 /*Result=*/b_.getVoidTy(), 2255 /*Params=*/{i8_ptr_type, intptr_type}, /*isVarArg=*/false)) 2256 .getCallee()); 2257 Call(msan_unpoison_ir_function, 2258 {PointerCast(operands_alloca, i8_ptr_type), 2259 llvm::ConstantInt::get( 2260 intptr_type, *operands_alloca->getAllocationSizeInBits(dl) / 8)}); 2261 } 2262 auto* custom_call_ir_function = llvm::dyn_cast<llvm::Function>( 2263 module_ 2264 ->getOrInsertFunction( 2265 custom_call->custom_call_target(), 2266 llvm::FunctionType::get( 2267 /*Result=*/b_.getVoidTy(), 2268 /*Params=*/{i8_ptr_type, operands_alloca->getType()}, 2269 /*isVarArg=*/false)) 2270 .getCallee()); 2271 2272 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); 2273 // Write the tuple table if the output is a tuple. 2274 if (custom_call->shape().IsTuple()) { 2275 std::vector<llvm::Value*> base_ptrs; 2276 for (int i = 0; i < ShapeUtil::TupleElementCount(custom_call->shape()); 2277 ++i) { 2278 const Shape& elem_shape = 2279 ShapeUtil::GetTupleElementShape(custom_call->shape(), i); 2280 TF_RET_CHECK(!elem_shape.IsTuple()) << "Nested tuples not implemented"; 2281 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, 2282 assignment_.GetUniqueSlice(custom_call, {i})); 2283 llvm::Value* addr = EmitBufferPointer(slice, elem_shape); 2284 base_ptrs.push_back(addr); 2285 } 2286 llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_); 2287 } 2288 auto* output_address_arg = 2289 PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type); 2290 2291 Call(custom_call_ir_function, {output_address_arg, operands_alloca}); 2292 2293 return Status::OK(); 2294 } 2295 2296 Status IrEmitter::HandleWhile(HloInstruction* xla_while) { 2297 // Precondition: Condition computation must return a scalar bool. 2298 HloComputation* condition = xla_while->while_condition(); 2299 TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) && 2300 condition->root_instruction()->shape().element_type() == PRED) 2301 << "While condition computation must return bool; got: " 2302 << ShapeUtil::HumanString(condition->root_instruction()->shape()); 2303 // Check that all while-related buffers share an allocation slice. 
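  // Concretely: the while result, its init operand, the condition parameter,
  // the body parameter and the body root must all map to the same slice, so
  // the loop state can be updated in place across iterations.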
2304 TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( 2305 xla_while->shape(), 2306 [this, &xla_while](const Shape& /*subshape*/, 2307 const ShapeIndex& index) -> Status { 2308 auto check = [this](const HloInstruction* a, const HloInstruction* b, 2309 const ShapeIndex& index) { 2310 const BufferAllocation::Slice slice_a = 2311 assignment_.GetUniqueSlice(a, index).ConsumeValueOrDie(); 2312 const BufferAllocation::Slice slice_b = 2313 assignment_.GetUniqueSlice(b, index).ConsumeValueOrDie(); 2314 if (slice_a != slice_b) { 2315 return InternalError( 2316 "instruction %s %s does not share slice with " 2317 "instruction %s %s", 2318 a->ToString(), slice_a.ToString(), b->ToString(), 2319 slice_b.ToString()); 2320 } 2321 return Status::OK(); 2322 }; 2323 TF_RETURN_IF_ERROR(check(xla_while, xla_while->operand(0), index)); 2324 TF_RETURN_IF_ERROR(check( 2325 xla_while, xla_while->while_condition()->parameter_instruction(0), 2326 index)); 2327 TF_RETURN_IF_ERROR( 2328 check(xla_while, xla_while->while_body()->parameter_instruction(0), 2329 index)); 2330 TF_RETURN_IF_ERROR(check( 2331 xla_while, xla_while->while_body()->root_instruction(), index)); 2332 return Status::OK(); 2333 })); 2334 2335 // Set emitted value to that of 'init' with which it shares an allocation. 2336 const HloInstruction* init = xla_while->operand(0); 2337 emitted_value_[xla_while] = GetEmittedValueFor(init); 2338 2339 // Generating: 2340 // while (Condition(while_result)) { 2341 // // CopyInsertion pass inserts copies which enable 'while_result' to 2342 // // be passed back in as 'Body' parameter. 2343 // while_result = Body(while_result); // Insert 2344 // } 2345 2346 // Terminates the current block with a branch to a while header. 2347 llvm::BasicBlock* header_bb = llvm::BasicBlock::Create( 2348 module_->getContext(), IrName(xla_while, "header"), 2349 compute_function_->function()); 2350 Br(header_bb); 2351 b_.SetInsertPoint(header_bb); 2352 2353 // Calls the condition function to determine whether to proceed with the 2354 // body. It must return a bool, so use the scalar call form. 2355 EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond")); 2356 llvm::Value* while_predicate = ICmpNE( 2357 Load(GetBufferForGlobalCallReturnValue(*xla_while->while_condition())), 2358 llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0)); 2359 2360 // Branches to the body or to the while exit depending on the condition. 2361 llvm::BasicBlock* body_bb = 2362 llvm::BasicBlock::Create(module_->getContext(), IrName(xla_while, "body"), 2363 compute_function_->function()); 2364 llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create( 2365 module_->getContext(), IrName(xla_while, "exit")); 2366 CondBr(while_predicate, body_bb, exit_bb); 2367 2368 // Calls the body function from the body block. 2369 b_.SetInsertPoint(body_bb); 2370 2371 // Calls the body function. 2372 EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body")); 2373 2374 // Finishes with a branch back to the header. 2375 Br(header_bb); 2376 2377 // Adds the exit block to the function and sets the insert point there. 
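  // The emitted control flow is: the current block branches to 'header',
  // 'header' conditionally branches to 'body' or 'exit', and 'body' branches
  // back to 'header'.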
2378 compute_function_->function()->getBasicBlockList().push_back(exit_bb); 2379 b_.SetInsertPoint(exit_bb); 2380 2381 return Status::OK(); 2382 } 2383 2384 StatusOr<bool> IrEmitter::EmitFastConcatenate( 2385 HloInstruction* concatenate, absl::Span<HloInstruction* const> operands, 2386 string* failure_reason) { 2387 if (ShouldEmitParallelLoopFor(*concatenate)) { 2388 *failure_reason = 2389 "cannot generate memcpy-based concat for the parallel CPU backend"; 2390 return false; 2391 } 2392 2393 const Shape& output_shape = concatenate->shape(); 2394 for (auto* op : operands) { 2395 if (!LayoutUtil::Equal(op->shape().layout(), output_shape.layout())) { 2396 *failure_reason = "operand has mismatching layouts"; 2397 return false; 2398 } 2399 } 2400 2401 // We split the dimensions into three categories: the dimension over which we 2402 // are concatenating (concat_dim), the dimensions that are minor to it 2403 // (inner_dims) and the dimensions that are major to it (outer_dims). 2404 2405 int64 concat_dim = concatenate->dimensions(0); 2406 const Layout& output_layout = output_shape.layout(); 2407 auto output_min2maj = LayoutUtil::MinorToMajor(output_layout); 2408 auto concat_dim_layout_itr = absl::c_find(output_min2maj, concat_dim); 2409 2410 std::vector<int64> inner_dims(output_min2maj.begin(), concat_dim_layout_itr); 2411 std::vector<int64> outer_dims(std::next(concat_dim_layout_itr), 2412 output_min2maj.end()); 2413 2414 llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); 2415 2416 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate)); 2417 llvm_ir::IrArray target_array = GetIrArrayFor(concatenate); 2418 2419 llvm_ir::ForLoopNest loops(IrName(concatenate), &b_); 2420 std::vector<llvm::Value*> target_multi_index = 2421 loops.AddLoopsForShapeOnDimensions(output_shape, outer_dims, "concat"); 2422 std::replace(target_multi_index.begin(), target_multi_index.end(), 2423 static_cast<llvm::Value*>(nullptr), 2424 static_cast<llvm::Value*>(b_.getInt64(0))); 2425 llvm_ir::IrArray::Index target_index(target_multi_index, output_shape, 2426 b_.getInt64Ty()); 2427 2428 if (!outer_dims.empty()) { 2429 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); 2430 } 2431 2432 PrimitiveType primitive_type = output_shape.element_type(); 2433 unsigned primitive_type_size = 2434 ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); 2435 2436 // Contiguous subregions from each operand to the concatenate contribute to a 2437 // contiguous subregion in the target buffer starting at target_region_begin. 2438 llvm::Value* target_region_begin = BitCast( 2439 target_array.EmitArrayElementAddress(target_index, &b_, "target_region"), 2440 i8_ptr_type); 2441 int64 byte_offset_into_target_region = 0; 2442 2443 int64 inner_dims_product = 2444 std::accumulate(inner_dims.begin(), inner_dims.end(), 1l, 2445 [&](int64 product, int64 inner_dim) { 2446 return product * output_shape.dimensions(inner_dim); 2447 }); 2448 2449 // For each operand, emit a memcpy from the operand to the target of size 2450 // equal to the product of inner dimensions. 
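  // For example (illustrative shapes): concatenating f32[2,3]{1,0} and
  // f32[2,5]{1,0} along dimension 1, the concat dimension is most minor, so
  // inner_dims is empty (inner_dims_product == 1), dimension 0 is the single
  // outer loop, and each outer iteration copies 3 elements from the first
  // operand and 5 from the second while advancing
  // byte_offset_into_target_region.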
2451 for (HloInstruction* operand : operands) { 2452 const Shape& input_shape = operand->shape(); 2453 llvm_ir::IrArray source_array = GetIrArrayFor(operand); 2454 llvm::Value* copy_source_address = BitCast( 2455 source_array.EmitArrayElementAddress(target_index, &b_, "src_addr"), 2456 i8_ptr_type); 2457 2458 llvm::Value* copy_target_address = 2459 GEP(target_region_begin, b_.getInt64(byte_offset_into_target_region)); 2460 2461 EmitTransferElements( 2462 copy_target_address, copy_source_address, 2463 inner_dims_product * input_shape.dimensions(concat_dim), primitive_type, 2464 target_array, source_array); 2465 2466 byte_offset_into_target_region += inner_dims_product * 2467 input_shape.dimensions(concat_dim) * 2468 primitive_type_size; 2469 } 2470 2471 if (!outer_dims.empty()) { 2472 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); 2473 } 2474 2475 return true; 2476 } 2477 2478 void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, 2479 int64 element_count, 2480 PrimitiveType primitive_type, 2481 const llvm_ir::IrArray& target_array, 2482 const llvm_ir::IrArray& source_array) { 2483 unsigned primitive_type_size = 2484 ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); 2485 unsigned element_alignment = tensorflow::MathUtil::GCD<unsigned>( 2486 primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type)); 2487 llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual( 2488 llvm_ir::PrimitiveTypeToIrType(primitive_type, module_)); 2489 2490 if (element_count == 1) { 2491 auto* load_instruction = 2492 AlignedLoad(BitCast(source, primitive_ptr_type), element_alignment); 2493 source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction); 2494 auto* store_instruction = 2495 AlignedStore(load_instruction, BitCast(target, primitive_ptr_type), 2496 element_alignment); 2497 target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction); 2498 } else { 2499 auto* memcpy_instruction = MemCpy( 2500 target, /*DstAlign=*/element_alignment, source, 2501 /*SrcAlign=*/element_alignment, element_count * primitive_type_size); 2502 2503 // The memcpy does the load and the store internally. The aliasing related 2504 // metadata has to reflect that. 
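    // The merged nodes are attached kind-by-kind below so the memcpy carries
    // the aliasing facts of both the source and the destination arrays.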
2505 std::map<int, llvm::MDNode*> merged_metadata = 2506 llvm_ir::MergeMetadata(&module_->getContext(), source_array.metadata(), 2507 target_array.metadata()); 2508 for (const auto& kind_md_pair : merged_metadata) { 2509 memcpy_instruction->setMetadata(kind_md_pair.first, kind_md_pair.second); 2510 } 2511 } 2512 } 2513 2514 Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { 2515 absl::Span<HloInstruction* const> operands(concatenate->operands()); 2516 string failure_reason; 2517 TF_ASSIGN_OR_RETURN( 2518 bool successful, 2519 EmitFastConcatenate(concatenate, operands, &failure_reason)); 2520 if (successful) { 2521 VLOG(1) << "Emitted fast concatenate for " << concatenate->ToString(); 2522 return Status::OK(); 2523 } 2524 2525 VLOG(1) << "Could not emit fast concatenate for " << concatenate->ToString() 2526 << ": " << failure_reason; 2527 2528 return DefaultAction(concatenate); 2529 } 2530 2531 Status IrEmitter::HandleConditional(HloInstruction* conditional) { 2532 auto branch_index = conditional->operand(0); 2533 int num_branches = conditional->branch_count(); 2534 TF_RET_CHECK(ShapeUtil::IsScalar(branch_index->shape()) && 2535 (branch_index->shape().element_type() == PRED || 2536 branch_index->shape().element_type() == S32)) 2537 << "Branch index on a conditional must be scalar bool or int32; got: " 2538 << ShapeUtil::HumanString(branch_index->shape()); 2539 2540 for (int b = 0; b < num_branches; ++b) { 2541 HloComputation* br_computation = conditional->branch_computation(b); 2542 TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(), 2543 br_computation->root_instruction()->shape())) 2544 << "Shape of conditional should be same as the shape of the " << b 2545 << "th branch computation; got: " 2546 << ShapeUtil::HumanString(conditional->shape()) << " and " 2547 << ShapeUtil::HumanString(br_computation->root_instruction()->shape()); 2548 } 2549 2550 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional)); 2551 2552 if (branch_index->shape().element_type() == PRED) { 2553 // Emit an if-else to LLVM: 2554 // if (pred) 2555 // cond_result = true_computation(true_operand) 2556 // else 2557 // cond_result = false_computation(false_operand) 2558 llvm::LoadInst* pred_value = Load( 2559 GetIrArrayFor(branch_index).GetBasePointer(), "load_predicate_value"); 2560 llvm::Value* pred_cond = 2561 ICmpNE(pred_value, 2562 llvm::ConstantInt::get( 2563 llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), 2564 "boolean_predicate"); 2565 llvm_ir::LlvmIfData if_data = 2566 llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_); 2567 2568 SetToFirstInsertPoint(if_data.true_block, &b_); 2569 EmitGlobalCall(*conditional->branch_computation(0), 2570 IrName(conditional, "_true")); 2571 2572 SetToFirstInsertPoint(if_data.false_block, &b_); 2573 EmitGlobalCall(*conditional->branch_computation(1), 2574 IrName(conditional, "_false")); 2575 2576 SetToFirstInsertPoint(if_data.after_block, &b_); 2577 return Status::OK(); 2578 } 2579 // We emit a switch statement to LLVM: 2580 // switch (branch_index) { 2581 // default: 2582 // result = branch_computations[num_branches-1](operands[num_branches-1]); 2583 // break; 2584 // case 0: 2585 // result = branch_computations[0](operands[0]); break; 2586 // case 1: 2587 // result = branch_computations[1](operands[1]); break; 2588 // ... 
  //   case [[num_branches-2]]:
  //     result = branch_computations[num_branches-2](operands[num_branches-2]);
  //     break;
  // }
  llvm::LoadInst* branch_index_value = Load(
      GetIrArrayFor(branch_index).GetBasePointer(), "load_branch_index_value");

  auto case_block = b_.GetInsertBlock();
  llvm::BasicBlock* after_block;
  // Add a terminator to the case block, if necessary.
  if (case_block->getTerminator() == nullptr) {
    after_block = llvm_ir::CreateBasicBlock(nullptr, "case-after", &b_);
    b_.SetInsertPoint(case_block);
    b_.CreateBr(after_block);
  } else {
    after_block =
        case_block->splitBasicBlock(b_.GetInsertPoint(), "case-after");
  }
  // Our basic block should now end with an unconditional branch. Remove it;
  // we're going to replace it with a switch-based branch.
  case_block->getTerminator()->eraseFromParent();

  // Lower the default branch computation.
  auto default_block = llvm_ir::CreateBasicBlock(nullptr, "case-default", &b_);
  b_.SetInsertPoint(default_block);
  EmitGlobalCall(*conditional->branch_computation(num_branches - 1),
                 IrName(conditional, "_default"));
  b_.CreateBr(after_block);

  // Prepare the switch (branch_index) { ... } instruction.
  b_.SetInsertPoint(case_block);
  llvm::SwitchInst* case_inst =
      b_.CreateSwitch(branch_index_value, default_block, num_branches - 1);
  // Lower each branch's computation.
  for (int b = 0; b < num_branches - 1; ++b) {  // last branch is default
    // Lower the case b: { ... ; break; } computation.
    auto branch_block =
        llvm_ir::CreateBasicBlock(nullptr, absl::StrCat("case-branch", b), &b_);
    b_.SetInsertPoint(branch_block);
    EmitGlobalCall(*conditional->branch_computation(b),
                   IrName(conditional, absl::StrCat("_branch", b)));
    b_.CreateBr(after_block);
    case_inst->addCase(b_.getInt32(b), branch_block);
  }

  SetToFirstInsertPoint(after_block, &b_);
  return Status::OK();
}

Status IrEmitter::HandleAfterAll(HloInstruction* after_all) {
  TF_RET_CHECK(ByteSizeOf(after_all->shape()) == 0);
  // No code to generate, but we need to emit an address for book-keeping.
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(after_all));
  return Status::OK();
}

Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
  // AddDependency just forwards its zeroth operand.
  emitted_value_[add_dependency] =
      GetEmittedValueFor(add_dependency->operand(0));
  return Status::OK();
}

Status IrEmitter::HandleRng(HloInstruction* rng) {
  ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
  for (const HloInstruction* operand : rng->operands()) {
    operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
      return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
    };
  }

  CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
  TF_RETURN_IF_ERROR(EmitTargetElementLoop(
      rng, elemental_emitter.MakeElementGenerator(rng, operand_to_generator)));

  llvm_ir::IncrementVariableForPhiloxRngState(1, module_, &b_);

  return Status::OK();
}

Status IrEmitter::FinishVisit(HloInstruction* root) {
  // When this method is called, we should have already emitted an IR value for
  // the root (return) op. The IR value holds the address of the buffer holding
  // the value.
  // If the root is a constant or parameter, we perform a memcpy from this
  // buffer to the retval buffer of the computation. Otherwise, there's nothing
  // to do since the result was already written directly into the output
  // buffer.
  VLOG(2) << "FinishVisit root: " << root->ToString();
  if (root->opcode() == HloOpcode::kOutfeed) {
    VLOG(2) << " outfeed with value: "
            << llvm_ir::DumpToString(*GetEmittedValueFor(root->operand(0)));
  } else {
    VLOG(2) << " value: " << llvm_ir::DumpToString(*GetEmittedValueFor(root));
  }

  auto record_complete_computation = [&](llvm::Value* prof_counter) {
    if (prof_counter) {
      profiling_state_.RecordCompleteComputation(&b_, prof_counter);
    }
  };

  // For the entry computation this increment is cumulative of embedded
  // computations since it includes cycles spent in computations invoked by
  // While, Call etc.
  record_complete_computation(GetProfileCounterFor(*root->parent()));
  return Status::OK();
}

template <typename T>
llvm::Value* IrEmitter::GetProfileCounterCommon(
    const T& hlo,
    const std::unordered_map<const T*, int64>& profile_index_map) {
  auto it = profile_index_map.find(&hlo);
  if (it == profile_index_map.end()) {
    return nullptr;
  }

  int64 prof_counter_idx = it->second;
  string counter_name = IrName("prof_counter", hlo.name());
  return GEP(GetProfileCountersArgument(), b_.getInt64(prof_counter_idx),
             counter_name);
}

void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b,
                                                     llvm::Value* prof_counter,
                                                     llvm::Value* cycle_end,
                                                     llvm::Value* cycle_start) {
  auto* cycle_diff = b->CreateSub(cycle_end, cycle_start);
  llvm::LoadInst* old_cycle_count =
      b->CreateLoad(prof_counter, "old_cycle_count");
  auto* new_cycle_count =
      b->CreateAdd(cycle_diff, old_cycle_count, "new_cycle_count");
  b->CreateStore(new_cycle_count, prof_counter);
}

llvm::Value* IrEmitter::ProfilingState::ReadCycleCounter(llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  if (!use_rdtscp_) {
    llvm::Function* func_llvm_readcyclecounter =
        llvm::Intrinsic::getDeclaration(module,
                                        llvm::Intrinsic::readcyclecounter);
    return b->CreateCall(func_llvm_readcyclecounter);
  }
  llvm::Function* func_llvm_x86_rdtscp =
      llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::x86_rdtscp);
  if (!aux_i8ptr_) {
    llvm::AllocaInst* rdtscp_aux =
        llvm_ir::EmitAllocaAtFunctionEntry(b->getInt32Ty(), "rdtscp_aux", b);
    aux_i8ptr_ = b->CreateBitCast(rdtscp_aux, b->getInt8PtrTy());
  }
  llvm::ConstantInt* alloca_size = b->getInt64(4);
  llvm::Function* func_llvm_lifetime_start =
      llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::lifetime_start);
  b->CreateCall(func_llvm_lifetime_start, {alloca_size, aux_i8ptr_});
  llvm::Value* rdtscp_call = b->CreateCall(func_llvm_x86_rdtscp, aux_i8ptr_);
  llvm::Function* func_llvm_lifetime_end =
      llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::lifetime_end);
  b->CreateCall(func_llvm_lifetime_end, {alloca_size, aux_i8ptr_});
  return rdtscp_call;
}

void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* b,
                                                 HloInstruction* hlo) {
  auto* cycle_start = ReadCycleCounter(b);
  cycle_start->setName(IrName(hlo, "cycle_start"));
  cycle_starts_[hlo] = cycle_start;
  if
void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* b,
                                                 HloInstruction* hlo,
                                                 llvm::Value* prof_counter) {
  auto* cycle_end = ReadCycleCounter(b);
  cycle_end->setName(IrName(hlo, "cycle_end"));
  auto* cycle_start = cycle_starts_[hlo];
  UpdateProfileCounter(b, prof_counter, cycle_end, cycle_start);
  last_read_cycle_end_ = cycle_end;
}

void IrEmitter::ProfilingState::RecordCompleteComputation(
    llvm::IRBuilder<>* b, llvm::Value* prof_counter) {
  if (last_read_cycle_end_ && first_read_cycle_start_) {
    UpdateProfileCounter(b, prof_counter, last_read_cycle_end_,
                         first_read_cycle_start_);
  }
}

Status IrEmitter::Preprocess(HloInstruction* hlo) {
  VLOG(3) << "Visiting: " << hlo->ToString();
  if (instruction_to_profile_idx_.count(hlo)) {
    profiling_state_.RecordCycleStart(&b_, hlo);
  }
  return Status::OK();
}

Status IrEmitter::Postprocess(HloInstruction* hlo) {
  if (auto* prof_counter = GetProfileCounterFor(*hlo)) {
    profiling_state_.RecordCycleDelta(&b_, hlo, prof_counter);
  }
  return Status::OK();
}

llvm_ir::IrArray IrEmitter::GetIrArrayFor(const HloInstruction* hlo) {
  llvm::Value* value_for_op = GetEmittedValueFor(hlo);

  llvm_ir::IrArray array(value_for_op, hlo->shape());
  AddAliasingInformationToIrArray(*hlo, &array);
  return array;
}

std::vector<llvm_ir::IrArray> IrEmitter::GetIrArraysForOperandsOf(
    const HloInstruction* hlo) {
  std::vector<llvm_ir::IrArray> arrays;
  std::transform(
      hlo->operands().begin(), hlo->operands().end(),
      std::back_inserter(arrays),
      [&](const HloInstruction* operand) { return GetIrArrayFor(operand); });
  return arrays;
}

llvm::Value* IrEmitter::GetEmittedValueFor(const HloInstruction* hlo) {
  auto it = emitted_value_.find(hlo);
  if (it == emitted_value_.end()) {
    LOG(FATAL) << "could not find emitted value for: " << hlo->ToString();
  }
  return it->second;
}

llvm::Type* IrEmitter::IrShapeType(const Shape& shape) {
  return llvm_ir::ShapeToIrType(shape, module_);
}

llvm::Value* IrEmitter::GetProfileCountersArgument() {
  return compute_function_->profile_counters_arg();
}

llvm::Value* IrEmitter::GetBufferTableArgument() {
  return compute_function_->buffer_table_arg();
}

llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() {
  return compute_function_->exec_run_options_arg();
}
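// Illustrative sketch only (pseudo-C, not emitted code): the helper below
// resolves a thread-local slice to one of three addresses, where T is the IR
// type for target_shape:
//
//   if (slice is the computation root)   addr = (T*)retval_arg;
//   else if (slice backs parameter N)    addr = (T*)params_arg[N];
//   else                                 addr = (T*)alloca;  // cached per
//                                                            // (function, slice)
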
llvm::Value* IrEmitter::EmitThreadLocalBufferPointer(
    const BufferAllocation::Slice& slice, const Shape& target_shape) {
  const BufferAllocation& allocation = *slice.allocation();
  llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
    if (slice == computation_root_allocation_) {
      llvm::Argument* retval = compute_function_->result_arg();
      llvm::AttrBuilder attr_builder;
      attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
      attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
      retval->addAttrs(attr_builder);
      return retval;
    }

    auto param_it =
        computation_parameter_allocations_.find(slice.allocation()->index());
    if (param_it != computation_parameter_allocations_.end()) {
      int64 param_number = param_it->second;
      // We have to access the parameter at offset param_number in the params
      // array. The code generated here is equivalent to this C code:
      //
      //   i8* param_address_untyped = params[param_number];
      //   Param* param_address_typed = (Param*)param_address_untyped;
      //
      // Where Param is the actual element type of the underlying buffer (for
      // example, float for an XLA F32 element type).
      llvm::Value* params = compute_function_->parameters_arg();
      llvm::Value* param_address_offset =
          llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_);
      llvm::LoadInst* param_address_untyped = Load(param_address_offset);

      if (!target_shape.IsOpaque()) {
        AttachAlignmentMetadataForLoad(param_address_untyped, target_shape);
        AttachDereferenceableMetadataForLoad(param_address_untyped,
                                             target_shape);
      }
      return param_address_untyped;
    }

    // Thread-local allocations should only be assigned a single buffer.
    const auto& assigned_buffers = allocation.assigned_buffers();
    CHECK_EQ(1, assigned_buffers.size());
    const Shape& shape = assigned_buffers.begin()->first->shape();

    std::pair<llvm::Function*, BufferAllocation::Slice> key = {
        compute_function_->function(), slice};
    auto buf_it = thread_local_buffers_.find(key);
    if (buf_it == thread_local_buffers_.end()) {
      llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry(
          IrShapeType(shape), absl::StrCat("thread_local", slice.ToString()),
          &b_, MinimumAlignmentForShape(target_shape));
      auto it_inserted_pair = thread_local_buffers_.insert({key, buffer});
      CHECK(it_inserted_pair.second);
      buf_it = it_inserted_pair.first;
    }
    return buf_it->second;
  }();
  return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo());
}
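// Illustrative sketch only (pseudo-C, not emitted code): the address
// computation emitted by EmitGlobalBufferPointer below is roughly
//
//   char* base = buffer_table[slice.index()];   // invariant, aligned load
//   T*    addr = (T*)(base + slice.offset());   // T = IR type for target_shape
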
llvm::Value* IrEmitter::EmitGlobalBufferPointer(
    const BufferAllocation::Slice& slice, const Shape& target_shape) {
  const BufferAllocation& allocation = *slice.allocation();
  llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
      GetBufferTableArgument(), slice.index(), &b_);
  llvm::LoadInst* tempbuf_address_base = Load(tempbuf_address_ptr);
  if (hlo_module_config_.debug_options()
          .xla_llvm_enable_invariant_load_metadata()) {
    tempbuf_address_base->setMetadata(
        llvm::LLVMContext::MD_invariant_load,
        llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{}));
  }
  AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size());
  AttachDereferenceableMetadataForLoad(tempbuf_address_base, allocation.size());

  llvm::Value* tempbuf_address_untyped = tempbuf_address_base;
  if (slice.offset() > 0) {
    // Adjust the address to account for the slice offset.
    tempbuf_address_untyped =
        InBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset()));
  }
  return BitCast(tempbuf_address_untyped,
                 IrShapeType(target_shape)->getPointerTo());
}

llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice,
                                          const Shape& target_shape) {
  if (slice.allocation()->is_thread_local()) {
    return EmitThreadLocalBufferPointer(slice, target_shape);
  } else if (slice.allocation()->is_constant()) {
    return BitCast(
        FindOrDie(constant_buffer_to_global_, slice.allocation()->index()),
        IrShapeType(target_shape)->getPointerTo());
  } else {
    return EmitGlobalBufferPointer(slice, target_shape);
  }
}

Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
  const Shape& target_shape = op->shape();
  TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
                      assignment_.GetUniqueTopLevelSlice(op));
  llvm::Value* addr = EmitBufferPointer(slice, target_shape);
  addr->setName(IrName(op));
  emitted_value_[op] = addr;
  return Status::OK();
}
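// Illustrative sketch only (pseudo-C, not emitted code): for a multi-output
// fusion with a tuple-shaped result (e0, e1, ...), EmitTargetElementLoop below
// emits roughly
//
//   for (index over the output domain) {
//     (v0, v1, ...) = element_generator(index);
//     out0[index] = v0; out1[index] = v1; ...
//   }
//   target = tuple(&out0, &out1, ...);   // llvm_ir::EmitTuple
//
// For all other ops it emits a single (possibly parallel) elementwise loop
// that writes the target buffer directly.
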
Status IrEmitter::EmitTargetElementLoop(
    HloInstruction* target_op,
    const llvm_ir::ElementGenerator& element_generator) {
  return EmitTargetElementLoop(target_op, /*desc=*/"", element_generator);
}

Status IrEmitter::EmitTargetElementLoop(
    HloInstruction* target_op, absl::string_view desc,
    const llvm_ir::ElementGenerator& element_generator) {
  VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString();

  const Shape& target_shape = target_op->shape();
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op));
  llvm_ir::IrArray target_array = GetIrArrayFor(target_op);

  if (target_op->IsMultiOutputFusion()) {
    // For multi-output fusion, emit an array for each tuple element and then
    // the root tuple itself.
    TF_RET_CHECK(num_dynamic_loop_bounds_ == 0);
    std::vector<llvm_ir::IrArray> output_arrays;
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(target_shape); ++i) {
      TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
                          assignment_.GetUniqueSlice(target_op, {i}));
      const Shape& element_shape = ShapeUtil::GetSubshape(target_shape, {i});
      llvm::Value* op_target_address = EmitBufferPointer(slice, element_shape);
      output_arrays.push_back(
          llvm_ir::IrArray(op_target_address, element_shape));
    }
    TF_RETURN_IF_ERROR(
        llvm_ir::LoopEmitter(element_generator, output_arrays, &b_)
            .EmitLoop(IrName(target_op)));

    std::vector<llvm::Value*> tuple_operand_ptrs;
    for (int64 i = 0; i < output_arrays.size(); ++i) {
      tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
    }
    llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &b_);

  } else {
    if (ShouldEmitParallelLoopFor(*target_op)) {
      // Emit code to read dynamic loop bounds from the compute function
      // argument.
      std::vector<std::pair<llvm::Value*, llvm::Value*>> dynamic_loop_bounds =
          compute_function_->GetDynamicLoopBounds();
      // Emit a parallel loop with dynamic loop bounds for the most-major
      // dimensions.
      TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array,
                                             &dynamic_loop_bounds, &b_)
                             .EmitLoop(IrName(target_op)));
    } else {
      TF_RETURN_IF_ERROR(
          llvm_ir::LoopEmitter(element_generator, target_array, &b_)
              .EmitLoop(IrName(target_op)));
    }
  }
  return Status::OK();
}

Status IrEmitter::EmitMemcpy(const HloInstruction& source,
                             const HloInstruction& destination) {
  llvm::Value* source_value = GetEmittedValueFor(&source);
  llvm::Value* destination_value = GetEmittedValueFor(&destination);
  int64 source_size = ByteSizeOf(source.shape());
  // TODO(b/63762267): Be more aggressive about specifying alignment.
  MemCpy(destination_value, /*DstAlign=*/1, source_value,
         /*SrcAlign=*/1, source_size);
  return Status::OK();
}

Status IrEmitter::ElementTypesSameAndSupported(
    const HloInstruction& instruction,
    absl::Span<const HloInstruction* const> operands,
    absl::Span<const PrimitiveType> supported_types) {
  for (auto operand : operands) {
    TF_RET_CHECK(
        ShapeUtil::SameElementType(operands[0]->shape(), operand->shape()));
  }

  TF_RET_CHECK(!operands.empty());
  PrimitiveType primitive_type = operands[0]->shape().element_type();
  if (!absl::c_linear_search(supported_types, primitive_type)) {
    return Unimplemented("unsupported operand type %s in op %s",
                         PrimitiveType_Name(primitive_type),
                         HloOpcodeString(instruction.opcode()));
  }
  return Status::OK();
}

Status IrEmitter::DefaultAction(HloInstruction* hlo) {
  ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
      return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
    };
  }
  CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
  return EmitTargetElementLoop(
      hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
}
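// Illustrative sketch only (pseudo-C, not emitted code; the exact argument
// list is produced by GetArrayFunctionCallArguments): a thread-local call to a
// scalar-returning nested computation, as emitted below, behaves like
//
//   T0 arg0_slot = arg0;                 // each scalar argument is spilled to
//   T1 arg1_slot = arg1;                 // a function-entry alloca
//   R  retval_slot;
//   callee(/*retval=*/&retval_slot, /*run_options=*/run_options,
//          /*params=*/{&arg0_slot, &arg1_slot, ...},
//          /*buffer_table=*/nullptr, /*prof_counters=*/prof_counters);
//   return retval_slot;
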
llvm::Value* IrEmitter::EmitThreadLocalCall(
    const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
    absl::string_view name) {
  CHECK(absl::c_binary_search(thread_local_computations_, &callee));

  const Shape& return_shape = callee.root_instruction()->shape();

  // Thread-local calls are currently restricted to scalar return values.
  // Lifting this restriction to allow "small" arrays should be easy; allowing
  // larger arrays is difficult because we allocate the buffer for the return
  // value on the stack.
  CHECK(ShapeUtil::IsScalar(return_shape));

  PrimitiveType return_type = return_shape.element_type();

  std::vector<llvm::Value*> parameter_addrs;
  for (llvm::Value* parameter : parameters) {
    CHECK(!parameter->getType()->isPointerTy());
    llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry(
        parameter->getType(), "arg_addr", &b_);
    Store(parameter, parameter_addr);
    parameter_addrs.push_back(parameter_addr);
  }

  llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
      llvm_ir::PrimitiveTypeToIrType(return_type, module_),
      absl::StrCat(name, "_retval_addr"), &b_,
      MinimumAlignmentForPrimitiveType(return_type));

  Call(FindOrDie(emitted_functions_, &callee),
       GetArrayFunctionCallArguments(
           parameter_addrs, &b_, name,
           /*return_value_buffer=*/return_value_buffer,
           /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
           /*buffer_table_arg=*/
           llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
           /*profile_counters_arg=*/GetProfileCountersArgument()));

  return Load(return_value_buffer);
}

void IrEmitter::EmitGlobalCall(const HloComputation& callee,
                               absl::string_view name) {
  CHECK(absl::c_binary_search(global_computations_, &callee));

  Call(FindOrDie(emitted_functions_, &callee),
       GetArrayFunctionCallArguments(
           /*parameter_addresses=*/{}, &b_, name,
           /*return_value_buffer=*/
           llvm::Constant::getNullValue(b_.getInt8PtrTy()),
           /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
           /*buffer_table_arg=*/GetBufferTableArgument(),
           /*profile_counters_arg=*/GetProfileCountersArgument()));
}

llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
    const HloComputation& callee) {
  const HloInstruction* root_inst = callee.root_instruction();
  if (root_inst->opcode() == HloOpcode::kOutfeed) {
    return llvm::Constant::getNullValue(b_.getInt8PtrTy());
  }

  const BufferAllocation::Slice root_buffer =
      assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie();
  return EmitBufferPointer(root_buffer, root_inst->shape());
}

}  // namespace cpu
}  // namespace xla