1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h" 17 18 #include <stddef.h> 19 #include <stdint.h> 20 #include <algorithm> 21 #include <iterator> 22 #include <limits> 23 #include <memory> 24 #include <utility> 25 #include <vector> 26 27 #include "tensorflow/core/lib/math/math_util.h" 28 #include "tensorflow/core/platform/logging.h" 29 // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" 30 #include "llvm/CodeGen/TargetRegisterInfo.h" 31 #include "llvm/CodeGen/TargetSubtargetInfo.h" 32 #include "llvm/IR/BasicBlock.h" 33 #include "llvm/IR/Constants.h" 34 #include "llvm/IR/GlobalVariable.h" 35 #include "llvm/IR/Instructions.h" 36 #include "llvm/IR/Intrinsics.h" 37 #include "llvm/IR/LLVMContext.h" 38 #include "tensorflow/compiler/xla/layout_util.h" 39 #include "tensorflow/compiler/xla/map_util.h" 40 #include "tensorflow/compiler/xla/service/buffer_assignment.h" 41 #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" 42 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" 43 #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" 44 #include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h" 45 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" 46 #include "tensorflow/compiler/xla/service/cpu/ir_function.h" 47 #include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h" 48 #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" 49 #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" 50 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" 51 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 52 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" 53 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" 54 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" 55 #include "tensorflow/compiler/xla/service/llvm_ir/ops.h" 56 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" 57 #include "tensorflow/compiler/xla/shape_util.h" 58 #include "tensorflow/compiler/xla/status_macros.h" 59 #include "tensorflow/compiler/xla/types.h" 60 #include "tensorflow/compiler/xla/util.h" 61 #include "tensorflow/compiler/xla/window_util.h" 62 #include "tensorflow/core/lib/core/bits.h" 63 #include "tensorflow/core/lib/core/errors.h" 64 #include "tensorflow/core/lib/gtl/array_slice.h" 65 #include "tensorflow/core/lib/gtl/flatmap.h" 66 #include "tensorflow/core/lib/gtl/flatset.h" 67 #include "tensorflow/core/lib/strings/strcat.h" 68 #include "tensorflow/core/lib/strings/stringprintf.h" 69 70 namespace xla { 71 72 namespace { 73 using llvm_ir::AsStringRef; 74 using llvm_ir::IrName; 75 using llvm_ir::SetToFirstInsertPoint; 76 namespace gtl = tensorflow::gtl; 77 } // namespace 78 79 namespace cpu { 80 81 IrEmitter::IrEmitter( 82 const HloModule& hlo_module, const BufferAssignment& assignment, 83 llvm::Module* llvm_module, 84 std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx, 85 std::unordered_map<const HloComputation*, int64> computation_to_profile_idx, 86 llvm::TargetMachine* target_machine, 87 ExternalConstantPool* external_constant_pool) 88 : assignment_(assignment), 89 module_(llvm_module), 90 arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), 91 ir_builder_(llvm_module->getContext()), 92 instruction_to_profile_idx_(std::move(instruction_to_profile_idx)), 93 computation_to_profile_idx_(std::move(computation_to_profile_idx)), 94 alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), 95 hlo_module_config_(hlo_module.config()), 96 parallel_cpu_backend_( 97 options::CpuParallelBackendRequested(hlo_module_config_)), 98 is_top_level_computation_(false), 99 target_machine_features_(target_machine), 100 external_constant_pool_(external_constant_pool) { 101 ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags( 102 /*fast_math_enabled=*/hlo_module_config_.debug_options() 103 .xla_enable_fast_math())); 104 } 105 106 StatusOr<llvm::Function*> IrEmitter::EmitComputation( 107 HloComputation* computation, const string& function_name_prefix, 108 bool is_top_level_computation, 109 std::vector<const HloInstruction*>* instruction_order) { 110 string function_name = name_uniquer_.GetUniqueName(function_name_prefix); 111 VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix 112 << "]; ordered? " << (instruction_order != nullptr); 113 is_top_level_computation_ = is_top_level_computation; 114 num_dynamic_loop_bounds_ = 0; 115 if (!computation->root_instruction()->outer_dimension_partitions().empty()) { 116 num_dynamic_loop_bounds_ = 117 computation->root_instruction()->outer_dimension_partitions().size(); 118 } 119 120 InitializeIrFunction(function_name); 121 // The rdtscp instruction is x86 specific. We will fallback to LLVM's generic 122 // readcyclecounter if it is unavailable. 123 bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 || 124 arch_type_ == llvm::Triple::ArchType::x86_64; 125 profiling_state_ = ProfilingState(use_rdtscp, GetProfileCountersArgument()); 126 if (instruction_order == nullptr) { 127 TF_RETURN_IF_ERROR(computation->Accept(this)); 128 } else { 129 TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, *instruction_order)); 130 } 131 llvm::Function* ir_function = compute_function_->function(); 132 InsertOrDie(&emitted_functions_, computation, ir_function); 133 // Delete 'compute_function', finalizing 'ir_function' and restoring caller 134 // IR insert point. 135 compute_function_.reset(); 136 return ir_function; 137 } 138 139 void IrEmitter::InitializeIrFunction(const string& function_name) { 140 // Functions with local linkage get an inlining bonus. Because we know 141 // a-priori that embedded functions (non-entry functions) will not have its 142 // name resolved, give it local linkage. 143 llvm::Function::LinkageTypes linkage = 144 is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage 145 : llvm::GlobalValue::InternalLinkage; 146 // Create and initialize new IrFunction. 147 compute_function_.reset( 148 new IrFunction(function_name, linkage, 149 options::OptimizeForSizeRequested(hlo_module_config_), 150 hlo_module_config_.debug_options().xla_enable_fast_math(), 151 module_, &ir_builder_, num_dynamic_loop_bounds_)); 152 } 153 154 IrEmitter::~IrEmitter() {} 155 156 Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { 157 VLOG(2) << "HandleBitcast: " << bitcast->ToString(); 158 emitted_value_[bitcast] = 159 ir_builder_.CreateBitCast(GetEmittedValueFor(bitcast->operand(0)), 160 IrShapeType(bitcast->shape())->getPointerTo(), 161 AsStringRef(IrName(bitcast))); 162 return Status::OK(); 163 } 164 165 Status IrEmitter::HandleConstant(HloInstruction* constant) { 166 VLOG(2) << "HandleConstant: " << constant->ToString(); 167 const Literal& literal = constant->literal(); 168 llvm::GlobalVariable* global_for_const; 169 170 // We avoid creating large constants in the LLVM IR since LLVM is not 171 // efficient for large constant arrays. We still emit "small enough" constant 172 // arrays into the Ir, in the off chance the LLVM optimizer can do something 173 // interesting with it. 174 const int kMaxInternalConstantSizeInBytes = 128; 175 if (external_constant_pool_ && 176 ByteSizeOf(literal.shape()) >= kMaxInternalConstantSizeInBytes) { 177 string global_name = tensorflow::strings::StrCat( 178 "constant_global_", external_global_constant_counter_++); 179 global_for_const = new llvm::GlobalVariable( 180 /*Module=*/*module_, 181 /*Type=*/IrShapeType(literal.shape()), 182 /*isConstant=*/true, 183 /*Linkage=*/llvm::GlobalValue::ExternalLinkage, 184 /*Initializer=*/nullptr, 185 /*Name=*/AsStringRef(global_name)); 186 global_for_const->setAlignment(MinimumAlignmentForShape(literal.shape())); 187 external_constant_pool_->Insert(global_name, literal, 188 MinimumAlignmentForShape(literal.shape())); 189 } else { 190 llvm::Constant* initializer = 191 llvm_ir::ConvertLiteralToIrConstant(literal, module_); 192 global_for_const = new llvm::GlobalVariable( 193 /*Module=*/*module_, 194 /*Type=*/initializer->getType(), 195 /*isConstant=*/true, 196 /*Linkage=*/llvm::GlobalValue::PrivateLinkage, 197 /*Initializer=*/initializer, 198 /*Name=*/""); 199 global_for_const->setAlignment(MinimumAlignmentForShape(literal.shape())); 200 } 201 emitted_value_[constant] = global_for_const; 202 VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*global_for_const); 203 VLOG(2) << " its type: " 204 << llvm_ir::DumpToString(*global_for_const->getType()); 205 return Status::OK(); 206 } 207 208 Status IrEmitter::HandleCopy(HloInstruction* copy) { 209 if (ShapeUtil::IsTuple(copy->shape())) { 210 // kCopy shallow copies a tuple so just memcpy the top-level buffer. 211 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy)); 212 return EmitMemcpy(*(copy->operand(0)), *copy); 213 } else { 214 // Use the elemental emitter for non-tuple shapes. 215 return DefaultAction(copy); 216 } 217 } 218 219 // Calculate the alignment of a buffer with a particular size. 220 int IrEmitter::MinimumAlignmentForBufferSize(int64 buffer_size) { 221 // GLibc returns a pointer with alignment 8 on 32-bit platforms and 16 on 222 // 64-bit platforms. TCMalloc returns a pointer with alignment 8 for 223 // allocations smaller than kMallocAlignmentThreshold bytes and at least 224 // alignment 16 for allocations greater than or equal to 225 // kMallocAlignmentThreshold bytes. N.B. We could improve on this lower bound 226 // by explicitly allocating the memory with posix_memalign. This is 227 // complicated by our desire to allow parameter buffers created by clients to 228 // be consumed directly by the JIT. 229 if (buffer_size == 0) { 230 // No need to align empty buffers. 231 return 1; 232 } 233 234 const int64 kMallocAlignmentThreshold = 512; 235 236 int pointer_size = module_->getDataLayout().getPointerSize(); 237 int buffer_alignment = buffer_size >= kMallocAlignmentThreshold 238 ? 2 * pointer_size 239 : pointer_size; 240 DCHECK_GT(buffer_alignment, 0); 241 242 return buffer_alignment; 243 } 244 245 // Calculate the alignment of a buffer allocated for a given primitive type. 246 int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) { 247 int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); 248 DCHECK_GE(byte_size, 0); 249 // Largest scalar is a complex64 so we don't need to worry about the 250 // int64->int truncation here. 251 DCHECK_LE(byte_size, 8); 252 return byte_size; 253 } 254 255 int64 IrEmitter::ByteSizeOf(const Shape& shape) const { 256 return llvm_ir::ByteSizeOf(shape, module_->getDataLayout()); 257 } 258 259 // Calculate the alignment of a buffer allocated for a given shape. 260 int IrEmitter::MinimumAlignmentForShape(const Shape& shape) { 261 if (ShapeUtil::IsScalar(shape)) { 262 return MinimumAlignmentForPrimitiveType(shape.element_type()); 263 } 264 265 int64 buffer_size = ByteSizeOf(shape); 266 DCHECK_GE(buffer_size, 0); 267 DCHECK_LE(buffer_size, SIZE_MAX); 268 269 return MinimumAlignmentForBufferSize(buffer_size); 270 } 271 272 void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load, 273 const Shape& shape) { 274 int alignment = MinimumAlignmentForShape(shape); 275 if (alignment > 1) { 276 llvm_ir::SetAlignmentMetadataForLoad(load, alignment); 277 } 278 } 279 280 void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load, 281 int64 buffer_size) { 282 int alignment = MinimumAlignmentForBufferSize(buffer_size); 283 if (alignment > 1) { 284 llvm_ir::SetAlignmentMetadataForLoad(load, alignment); 285 } 286 } 287 288 void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load, 289 const Shape& shape) { 290 AttachDereferenceableMetadataForLoad(load, ByteSizeOf(shape)); 291 } 292 293 void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load, 294 int64 buffer_size) { 295 if (buffer_size > 0) { 296 llvm_ir::SetDereferenceableMetadataForLoad(load, buffer_size); 297 } 298 } 299 300 Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) { 301 // A tuple is an array of pointers, one for each operand. Each pointer points 302 // to the output buffer of its corresponding operand. A GetTupleElement 303 // instruction forwards a pointer to the tuple element buffer at the given 304 // index. 305 auto operand = get_tuple_element->operand(0); 306 const Shape& shape = get_tuple_element->shape(); 307 emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement( 308 shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape), 309 GetEmittedValueFor(operand), &ir_builder_, module_); 310 return Status::OK(); 311 } 312 313 Status IrEmitter::HandleSelect(HloInstruction* select) { 314 auto pred = select->operand(0); 315 auto on_true = select->operand(1); 316 auto on_false = select->operand(2); 317 TF_RET_CHECK(pred->shape().element_type() == PRED); 318 319 if (ShapeUtil::IsTuple(select->shape())) { 320 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(select)); 321 llvm_ir::EmitTupleSelect( 322 GetIrArrayFor(select), GetIrArrayFor(pred), GetEmittedValueFor(on_true), 323 GetEmittedValueFor(on_false), &ir_builder_, module_); 324 return Status::OK(); 325 } 326 327 return DefaultAction(select); 328 } 329 330 Status IrEmitter::HandleInfeed(HloInstruction* infeed) { 331 VLOG(2) << "HandleInfeed: " << infeed->ToString(); 332 333 const Shape& shape = infeed->shape(); 334 335 // The infeed operation produces data (dequeued from the infeed queue) at this 336 // address, which has been provided by buffer assignment. 337 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(infeed)); 338 llvm_ir::IrArray infeed_array = GetIrArrayFor(infeed); 339 340 if (ShapeUtil::IsTuple(shape)) { 341 TF_RET_CHECK(!ShapeUtil::IsNestedTuple(shape)); 342 343 // For a tuple, we first copy each of the internal elements to 344 // their corresponding target locations. We then construct the 345 // tuple outer buffer containing pointers to the internal 346 // elements. 347 std::vector<llvm::Value*> tuple_element_addresses; 348 for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) { 349 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer, 350 assignment_.GetUniqueSlice(infeed, {i})); 351 352 const Shape& tuple_element_shape = 353 ShapeUtil::GetTupleElementShape(shape, i); 354 355 // Only the outer tuple buffer's target address is obtained from 356 // GetEmittedValueFor, to handle the case when Infeed is the root 357 // instruction. Target addresses for internal elements can be obtained 358 // from EmitTempBufferPointer. 359 llvm::Value* tuple_element_address = 360 EmitTempBufferPointer(buffer, tuple_element_shape); 361 362 TF_RETURN_IF_ERROR(EmitXfeedTransfer( 363 XfeedKind::kInfeed, tuple_element_shape, tuple_element_address)); 364 365 tuple_element_addresses.push_back(tuple_element_address); 366 } 367 368 llvm_ir::EmitTuple(infeed_array, tuple_element_addresses, &ir_builder_, 369 module_); 370 } else { 371 TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kInfeed, shape, 372 GetEmittedValueFor(infeed))); 373 } 374 375 return Status::OK(); 376 } 377 378 Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, 379 llvm::Value* program_buffer_address) { 380 int64 length = ByteSizeOf(shape); 381 if (length <= 0 || length > std::numeric_limits<int32>::max()) { 382 return InvalidArgument( 383 "xfeed (infeed or outfeed) buffer length %lld is outside the valid " 384 "size range", 385 length); 386 } 387 int32 length_32 = static_cast<int32>(length); 388 389 int32 shape_length; 390 TF_ASSIGN_OR_RETURN(llvm::Value * shape_ptr, 391 llvm_ir::EncodeSelfDescribingShapeConstant( 392 shape, &shape_length, &ir_builder_)); 393 394 // The signature of the acquire infeed buffer function is: 395 // 396 // (void*)(int32 length); 397 llvm::Type* int32_type = ir_builder_.getInt32Ty(); 398 llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); 399 llvm::FunctionType* acquire_type = llvm::FunctionType::get( 400 i8_ptr_type, {int32_type, i8_ptr_type, int32_type}, 401 /*isVarArg=*/false); 402 403 llvm::Function* acquire_func; 404 if (kind == XfeedKind::kInfeed) { 405 acquire_func = llvm::cast<llvm::Function>(module_->getOrInsertFunction( 406 runtime::kAcquireInfeedBufferForDequeueSymbolName, acquire_type)); 407 } else { 408 acquire_func = llvm::cast<llvm::Function>(module_->getOrInsertFunction( 409 runtime::kAcquireOutfeedBufferForPopulationSymbolName, acquire_type)); 410 } 411 acquire_func->setCallingConv(llvm::CallingConv::C); 412 413 // The signature of the release infeed buffer function is: 414 // 415 // (void)(int32 length, void* buffer); 416 llvm::FunctionType* release_type = llvm::FunctionType::get( 417 ir_builder_.getVoidTy(), 418 {int32_type, i8_ptr_type, i8_ptr_type, int32_type}, 419 /*isVarArg=*/false); 420 421 llvm::Function* release_func; 422 if (kind == XfeedKind::kInfeed) { 423 release_func = llvm::cast<llvm::Function>(module_->getOrInsertFunction( 424 runtime::kReleaseInfeedBufferAfterDequeueSymbolName, release_type)); 425 } else { 426 release_func = llvm::cast<llvm::Function>(module_->getOrInsertFunction( 427 runtime::kReleaseOutfeedBufferAfterPopulationSymbolName, release_type)); 428 } 429 release_func->setCallingConv(llvm::CallingConv::C); 430 431 // Implementation note: this call informs the runtime that it wants a buffer 432 // of size exactly 'length_32', and the runtime is responsible for 433 // check-failing the process if there is a mismatch, versus passing us back a 434 // buffer that we might overrun. 435 llvm::Value* acquired_pointer = ir_builder_.CreateCall( 436 acquire_func, {ir_builder_.getInt32(length_32), shape_ptr, 437 ir_builder_.getInt32(shape_length)}); 438 439 if (kind == XfeedKind::kInfeed) { 440 // Copy to the program buffer address from the acquired buffer. 441 ir_builder_.CreateMemCpy(program_buffer_address, acquired_pointer, 442 length_32, 1); 443 } else { 444 // Outfeed -- copy from the in-program address to the acquired buffer. 445 ir_builder_.CreateMemCpy(acquired_pointer, program_buffer_address, 446 length_32, 1); 447 } 448 449 ir_builder_.CreateCall(release_func, 450 {ir_builder_.getInt32(length_32), acquired_pointer, 451 shape_ptr, ir_builder_.getInt32(shape_length)}); 452 453 return Status::OK(); 454 } 455 456 Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { 457 HloInstruction* operand = outfeed->operands()[0]; 458 const Shape& operand_shape = operand->shape(); 459 460 llvm::Value* value = GetEmittedValueFor(operand); 461 if (!ShapeUtil::IsTuple(operand_shape)) { 462 return EmitXfeedTransfer(XfeedKind::kOutfeed, operand_shape, value); 463 } 464 465 TF_RET_CHECK(!ShapeUtil::IsNestedTuple(operand_shape)); 466 467 for (int64 i = 0; i < operand_shape.tuple_shapes_size(); ++i) { 468 const Shape& tuple_element_shape = 469 ShapeUtil::GetTupleElementShape(operand_shape, i); 470 llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement( 471 tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape), 472 value, &ir_builder_, module_); 473 TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed, 474 tuple_element_shape, tuple_element)); 475 } 476 477 return Status::OK(); 478 } 479 480 Status IrEmitter::HandleSort(HloInstruction* sort) { 481 // TODO(b/26783907): Implement sort on CPU. 482 return Unimplemented("Sort is not implemented on CPU."); 483 } 484 485 Status IrEmitter::HandleTuple(HloInstruction* tuple) { 486 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple)); 487 std::vector<llvm::Value*> base_ptrs; 488 for (auto operand : tuple->operands()) { 489 base_ptrs.push_back(GetEmittedValueFor(operand)); 490 } 491 llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &ir_builder_, module_); 492 return Status::OK(); 493 } 494 495 Status IrEmitter::HandleMap(HloInstruction* map) { 496 gtl::ArraySlice<HloInstruction*> operands(map->operands()); 497 HloComputation* function = map->to_apply(); 498 // The called computation should have been emitted previously. 499 llvm::Function* mapped_ir_function = FindOrDie(emitted_functions_, function); 500 501 return EmitTargetElementLoop(map, [this, map, operands, mapped_ir_function]( 502 const llvm_ir::IrArray::Index& index) { 503 std::vector<llvm::Value*> parameter_addresses; 504 for (const HloInstruction* operand : operands) { 505 const llvm_ir::IrArray& array = GetIrArrayFor(operand); 506 parameter_addresses.push_back( 507 array.EmitArrayElementAddress(index, &ir_builder_)); 508 } 509 return EmitElementFunctionCall(mapped_ir_function, map->shape(), 510 parameter_addresses, "map_function"); 511 }); 512 } 513 514 Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { 515 auto operand = reduce_window->operand(0); 516 const Window& window = reduce_window->window(); 517 HloComputation* function = reduce_window->to_apply(); 518 TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( 519 /*instruction=*/*reduce_window, /*operands=*/{operand}, 520 /*supported_types=*/{F32, BF16})); 521 522 // TODO(b/31410564): Implement dilation for reduce-window. 523 if (window_util::HasDilation(window)) { 524 return Unimplemented( 525 "Dilation for ReduceWindow is not implemented on CPU."); 526 } 527 528 // The called computation should have been emitted previously. 529 llvm::Function* reducer_function = FindOrDie(emitted_functions_, function); 530 531 // Pseudo code for reduce window: 532 // 533 // for (coordinates O in the output) 534 // value = init_value; 535 // for (coordinates W in the window) 536 // for each index i: 537 // input coordinates I_i = O_i * stride_i + W_i - pad_low_i 538 // if I within bounds of input: 539 // value = function(value, input(I)); 540 // output(O) = value; 541 // 542 // This is completely un-optimized and just here to have something 543 // that works. 544 return EmitTargetElementLoop( 545 reduce_window, [this, reduce_window, operand, window, 546 reducer_function](const llvm_ir::IrArray::Index& index) { 547 // We fold inputs into the accumulator and initialize it to 548 // the initial value on the reduce_window. 549 PrimitiveType operand_element_type = operand->shape().element_type(); 550 llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry( 551 llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), 552 "reduce_window_accumulator_address", &ir_builder_, 553 MinimumAlignmentForPrimitiveType(operand_element_type)); 554 ir_builder_.CreateStore(ir_builder_.CreateLoad(GetEmittedValueFor( 555 reduce_window->operand(1))), 556 accumulator_address); 557 558 llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), 559 &ir_builder_); 560 std::vector<int64> window_size; 561 for (const auto& dim : window.dimensions()) { 562 window_size.push_back(dim.size()); 563 } 564 const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape( 565 ShapeUtil::MakeShape(operand_element_type, window_size), "window"); 566 CHECK_EQ(window_index.size(), index.size()); 567 568 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); 569 570 llvm_ir::IrArray::Index input_index(index.size()); 571 llvm::Value* in_bounds_condition = nullptr; 572 for (size_t i = 0; i < index.size(); ++i) { 573 llvm::Value* strided_index = ir_builder_.CreateNSWMul( 574 index[i], ir_builder_.getInt64(window.dimensions(i).stride())); 575 input_index[i] = ir_builder_.CreateNSWSub( 576 ir_builder_.CreateNSWAdd(strided_index, window_index[i]), 577 ir_builder_.getInt64(window.dimensions(i).padding_low())); 578 579 // We need to check if 0 <= input_index[i] < bound, as 580 // otherwise we are in the padding so that we can skip the 581 // computation. That is equivalent to input_index[i] < bound 582 // as an *unsigned* comparison, since a negative value will 583 // wrap to a large positive value. 584 llvm::Value* index_condition = ir_builder_.CreateICmpULT( 585 input_index[i], ir_builder_.getInt64(ShapeUtil::GetDimension( 586 operand->shape(), i))); 587 if (in_bounds_condition == nullptr) { 588 in_bounds_condition = index_condition; 589 } else { 590 in_bounds_condition = 591 ir_builder_.CreateAnd(in_bounds_condition, index_condition); 592 } 593 } 594 CHECK(in_bounds_condition != nullptr); 595 596 llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( 597 in_bounds_condition, "in-bounds", &ir_builder_); 598 SetToFirstInsertPoint(if_data.true_block, &ir_builder_); 599 600 // We are not in the padding, so carry out the computation. 601 llvm_ir::IrArray input_array(GetIrArrayFor(operand)); 602 llvm::Value* input_value_address = 603 input_array.EmitArrayElementAddress(input_index, &ir_builder_); 604 llvm::Value* result = EmitElementFunctionCall( 605 reducer_function, reduce_window->shape(), 606 {accumulator_address, input_value_address}, "reducer_function"); 607 ir_builder_.CreateStore(result, accumulator_address); 608 609 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); 610 return ir_builder_.CreateLoad(accumulator_address); 611 }); 612 } 613 614 Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { 615 CHECK_EQ(select_and_scatter->operand_count(), 3); 616 const auto operand = select_and_scatter->operand(0); 617 const auto source = select_and_scatter->operand(1); 618 const auto init_value = select_and_scatter->operand(2); 619 const Window& window = select_and_scatter->window(); 620 PrimitiveType operand_element_type = operand->shape().element_type(); 621 const int64 rank = ShapeUtil::Rank(operand->shape()); 622 CHECK_EQ(rank, ShapeUtil::Rank(source->shape())); 623 CHECK_EQ(rank, window.dimensions_size()); 624 625 // TODO(b/31410564): Implement dilation for select-and-scatter. 626 if (window_util::HasDilation(window)) { 627 return Unimplemented( 628 "Dilation for SelectAndScatter is not implemented on CPU. "); 629 } 630 631 // The select and scatter computations should have been emitted previously. 632 llvm::Function* select_function = 633 FindOrDie(emitted_functions_, select_and_scatter->select()); 634 llvm::Function* scatter_function = 635 FindOrDie(emitted_functions_, select_and_scatter->scatter()); 636 637 // Pseudo code for select-and-scatter: 638 // 639 // initialized_flag is initially off for every window, and is turned on after 640 // the first iteration is completed and the first operand value is selected. 641 // 642 // output(*) = init_value 643 // for (coordinates S in the source) { 644 // initialized_flag = false 645 // for (coordinates W in the window) { 646 // I = S * stride + W - pad_low 647 // if I within bounds of operand: 648 // if !initialized_flag or select(selected_value, operand(I)) == false: 649 // selected_value = operand(I) 650 // selected_index = I 651 // initialized_flag = true 652 // } 653 // output(selected_index) = scatter(output(selected_index), source(S)) 654 // } 655 // 656 657 // Initialize the output array with the given init_value. 658 TF_RETURN_IF_ERROR(EmitTargetElementLoop( 659 select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"), 660 [this, init_value](const llvm_ir::IrArray::Index& target_index) { 661 llvm::Value* init_value_addr = GetEmittedValueFor(init_value); 662 return ir_builder_.CreateLoad(init_value_addr); 663 })); 664 665 // Create a loop to iterate over the source array to scatter to the output. 666 llvm_ir::ForLoopNest source_loops(IrName(select_and_scatter), &ir_builder_); 667 const llvm_ir::IrArray::Index source_index = 668 source_loops.AddLoopsForShape(source->shape(), "source"); 669 SetToFirstInsertPoint(source_loops.GetInnerLoopBodyBasicBlock(), 670 &ir_builder_); 671 672 // Allocate space to keep the currently selected value, its index, and 673 // the boolean initialized_flag, which is initially set to false. 674 llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry( 675 llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), 676 "selected_value_address", &ir_builder_, 677 MinimumAlignmentForPrimitiveType(operand_element_type)); 678 llvm::Value* selected_index_address = 679 llvm_ir::EmitAllocaAtFunctionEntryWithCount( 680 ir_builder_.getInt64Ty(), ir_builder_.getInt32(rank), 681 "selected_index_address", &ir_builder_); 682 llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry( 683 ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_); 684 ir_builder_.CreateStore(ir_builder_.getInt1(false), initialized_flag_address); 685 686 // Create the inner loop to iterate over the window. 687 llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), 688 &ir_builder_); 689 std::vector<int64> window_size; 690 for (const auto& dim : window.dimensions()) { 691 window_size.push_back(dim.size()); 692 } 693 const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape( 694 ShapeUtil::MakeShape(operand_element_type, window_size), "window"); 695 SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(), 696 &ir_builder_); 697 698 // Compute the operand index to visit and evaluate the condition whether the 699 // operand index is within the bounds. The unsigned comparison includes 700 // checking whether the operand index >= 0. 701 llvm_ir::IrArray::Index operand_index(source_index.size()); 702 llvm::Value* in_bounds_condition = ir_builder_.getTrue(); 703 for (int64 i = 0; i < rank; ++i) { 704 llvm::Value* strided_index = ir_builder_.CreateNSWMul( 705 source_index[i], ir_builder_.getInt64(window.dimensions(i).stride())); 706 operand_index[i] = ir_builder_.CreateNSWSub( 707 ir_builder_.CreateNSWAdd(strided_index, window_index[i]), 708 ir_builder_.getInt64(window.dimensions(i).padding_low())); 709 llvm::Value* index_condition = ir_builder_.CreateICmpULT( 710 operand_index[i], 711 ir_builder_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); 712 in_bounds_condition = 713 ir_builder_.CreateAnd(in_bounds_condition, index_condition); 714 } 715 CHECK(in_bounds_condition != nullptr); 716 717 // Only need to do something if the operand index is within the bounds. First 718 // check if the initialized_flag is set. 719 llvm_ir::LlvmIfData if_in_bounds = 720 llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &ir_builder_); 721 SetToFirstInsertPoint(if_in_bounds.true_block, &ir_builder_); 722 llvm_ir::LlvmIfData if_initialized = 723 llvm_ir::EmitIfThenElse(ir_builder_.CreateLoad(initialized_flag_address), 724 "initialized", &ir_builder_); 725 726 // If the initialized_flag is false, initialize the selected value and index 727 // with the currently visiting operand. 728 SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_); 729 const auto save_operand_index = 730 [&](const llvm_ir::IrArray::Index& operand_index) { 731 for (int64 i = 0; i < rank; ++i) { 732 llvm::Value* selected_index_address_slot = 733 ir_builder_.CreateInBoundsGEP(selected_index_address, 734 {ir_builder_.getInt32(i)}); 735 ir_builder_.CreateStore(operand_index[i], 736 selected_index_address_slot); 737 } 738 }; 739 llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); 740 llvm::Value* operand_data = 741 operand_array.EmitReadArrayElement(operand_index, &ir_builder_); 742 ir_builder_.CreateStore(operand_data, selected_value_address); 743 save_operand_index(operand_index); 744 ir_builder_.CreateStore(ir_builder_.getInt1(true), initialized_flag_address); 745 746 // If the initialized_flag is true, call the `select` function to potentially 747 // update the selected value and index with the currently visiting operand. 748 SetToFirstInsertPoint(if_initialized.true_block, &ir_builder_); 749 const Shape output_shape = ShapeUtil::MakeShape(PRED, {}); 750 llvm::Value* operand_address = 751 operand_array.EmitArrayElementAddress(operand_index, &ir_builder_); 752 llvm::Value* result = EmitElementFunctionCall( 753 select_function, output_shape, {selected_value_address, operand_address}, 754 "select_function"); 755 756 // If the 'select' function returns false, update the selected value and the 757 // index to the currently visiting operand. 758 llvm::Value* cond = ir_builder_.CreateICmpNE( 759 result, 760 llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), 761 "boolean_predicate"); 762 llvm_ir::LlvmIfData if_select_lhs = 763 llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_); 764 SetToFirstInsertPoint(if_select_lhs.false_block, &ir_builder_); 765 ir_builder_.CreateStore(ir_builder_.CreateLoad(operand_address), 766 selected_value_address); 767 save_operand_index(operand_index); 768 769 // After iterating over the window elements, scatter the source element to 770 // the selected index of the output. The value we store at the output 771 // location is computed by calling the `scatter` function with the source 772 // value and the current output value. 773 SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), 774 &ir_builder_); 775 llvm_ir::IrArray::Index selected_index; 776 for (int64 i = 0; i < rank; ++i) { 777 llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP( 778 selected_index_address, {ir_builder_.getInt32(i)}); 779 selected_index.push_back( 780 ir_builder_.CreateLoad(selected_index_address_slot)); 781 } 782 llvm_ir::IrArray source_array(GetIrArrayFor(source)); 783 llvm::Value* source_value_address = 784 source_array.EmitArrayElementAddress(source_index, &ir_builder_); 785 llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter)); 786 llvm::Value* output_value_address = 787 output_array.EmitArrayElementAddress(selected_index, &ir_builder_); 788 llvm::Value* scatter_value = EmitElementFunctionCall( 789 scatter_function, source->shape(), 790 {output_value_address, source_value_address}, "scatter_function"); 791 output_array.EmitWriteArrayElement(selected_index, scatter_value, 792 &ir_builder_); 793 794 SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), 795 &ir_builder_); 796 return Status::OK(); 797 } 798 799 Status IrEmitter::HandleDot(HloInstruction* dot) { 800 auto lhs = dot->operand(0); 801 auto rhs = dot->operand(1); 802 TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( 803 /*instruction=*/*dot, /*operands=*/{lhs, rhs}, 804 /*supported_types=*/{F16, F32, F64, C64})); 805 const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); 806 if (dnums.lhs_batch_dimensions_size() > 0 || 807 dnums.rhs_batch_dimensions_size() > 0) { 808 return Unimplemented("Dot with batch dimensions not implemented."); 809 } 810 811 if (dnums.lhs_contracting_dimensions_size() != 1) { 812 // This is disallowed by ShapeInference today. 813 return Unimplemented( 814 "Dot with multiple contracting dimensions not implemented."); 815 } 816 817 if (dnums.lhs_contracting_dimensions(0) != 818 std::min(lhs->shape().dimensions_size() - 1, 1) || 819 dnums.rhs_contracting_dimensions(0) != 0) { 820 return Unimplemented( 821 "Dot with non-standard contracting dimensions not implemented."); 822 } 823 824 llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); 825 llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); 826 827 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dot)); 828 llvm_ir::IrArray target_array = GetIrArrayFor(dot); 829 830 VLOG(2) << "HandleDot: "; 831 VLOG(2) << " lhs operand: " 832 << llvm_ir::DumpToString(*lhs_array.GetBasePointer()); 833 VLOG(2) << " rhs operand: " 834 << llvm_ir::DumpToString(*rhs_array.GetBasePointer()); 835 VLOG(2) << " target: " 836 << llvm_ir::DumpToString(*target_array.GetBasePointer()); 837 838 // Dot operation is complicated so we delegate to a helper class. 839 return DotOpEmitter::EmitDotOperation( 840 *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, 841 lhs_array, rhs_array, /*addend_array=*/nullptr, 842 GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_, 843 target_machine_features_); 844 } 845 846 Status IrEmitter::HandleConvolution(HloInstruction* convolution) { 847 auto lhs = convolution->operand(0); 848 auto rhs = convolution->operand(1); 849 const auto& window = convolution->window(); 850 TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( 851 /*instruction=*/*convolution, /*operands=*/{lhs, rhs}, 852 /*supported_types=*/{F16, F32, C64})); 853 854 const ConvolutionDimensionNumbers& dnums = 855 convolution->convolution_dimension_numbers(); 856 857 if (PotentiallyImplementedAsEigenConvolution(*convolution)) { 858 const Shape& lhs_shape = lhs->shape(); 859 const Shape& rhs_shape = rhs->shape(); 860 const Shape& convolution_shape = convolution->shape(); 861 // The input, kernel and output agree with respect to layout. 862 if (LayoutUtil::IsMonotonicWithDim0Major(lhs_shape.layout()) && 863 LayoutUtil::IsMonotonicWithDim0Major(rhs_shape.layout()) && 864 LayoutUtil::IsMonotonicWithDim0Major(convolution_shape.layout())) { 865 // We lower 1D convolutions into calls to the same Eigen function as 2D 866 // convolutions, except that we pretend that the 1D convolution is really 867 // a 2D convolution with the missing dimension set to 1. We also adjust 868 // the padding, dilation parameters as needed. 869 bool one_dim_convolution = lhs_shape.dimensions_size() == 3; 870 llvm::Value* lhs_address = GetEmittedValueFor(lhs); 871 llvm::Value* rhs_address = GetEmittedValueFor(rhs); 872 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(convolution)); 873 874 const ConvolutionDimensionNumbers& dnums = 875 convolution->convolution_dimension_numbers(); 876 877 // Input tensor. 878 const Shape& input_shape = convolution->operand(0)->shape(); 879 int64 input_batch = input_shape.dimensions(dnums.input_batch_dimension()); 880 int64 input_rows = 881 input_shape.dimensions(dnums.input_spatial_dimensions(0)); 882 int64 input_cols = 883 one_dim_convolution 884 ? 1 885 : input_shape.dimensions(dnums.input_spatial_dimensions(1)); 886 int64 input_channels = 887 input_shape.dimensions(dnums.input_feature_dimension()); 888 889 // Kernel tensor. 890 const Shape& kernel_shape = convolution->operand(1)->shape(); 891 int64 kernel_rows = 892 kernel_shape.dimensions(dnums.kernel_spatial_dimensions(0)); 893 int64 kernel_cols = 894 one_dim_convolution 895 ? 1 896 : kernel_shape.dimensions(dnums.kernel_spatial_dimensions(1)); 897 int64 kernel_channels = 898 kernel_shape.dimensions(dnums.kernel_input_feature_dimension()); 899 int64 kernel_filters = 900 kernel_shape.dimensions(dnums.kernel_output_feature_dimension()); 901 902 // Output tensor. 903 const Shape& convolution_shape = convolution->shape(); 904 int64 output_rows = 905 convolution_shape.dimensions(dnums.output_spatial_dimensions(0)); 906 int64 output_cols = one_dim_convolution 907 ? 1 908 : convolution_shape.dimensions( 909 dnums.output_spatial_dimensions(1)); 910 911 // Extract the window stride for the convolution. 912 const Window& window = convolution->window(); 913 int64 row_stride = window.dimensions(0).stride(); 914 int64 col_stride = 915 one_dim_convolution ? 1 : window.dimensions(1).stride(); 916 917 int64 padding_top = window.dimensions(0).padding_low(); 918 int64 padding_bottom = window.dimensions(0).padding_high(); 919 int64 padding_left = 920 one_dim_convolution ? 0 : window.dimensions(1).padding_low(); 921 int64 padding_right = 922 one_dim_convolution ? 0 : window.dimensions(1).padding_high(); 923 924 int64 lhs_row_dilation = window.dimensions(0).base_dilation(); 925 int64 lhs_col_dilation = 926 one_dim_convolution ? 1 : window.dimensions(1).base_dilation(); 927 int64 rhs_row_dilation = window.dimensions(0).window_dilation(); 928 int64 rhs_col_dilation = 929 one_dim_convolution ? 1 : window.dimensions(1).window_dilation(); 930 931 PrimitiveType primitive_type = lhs->shape().element_type(); 932 llvm::Type* ir_ptr_type = primitive_type == F16 933 ? ir_builder_.getHalfTy()->getPointerTo() 934 : ir_builder_.getFloatTy()->getPointerTo(); 935 llvm::Type* int64_type = ir_builder_.getInt64Ty(); 936 llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo(); 937 llvm::FunctionType* conv_type = llvm::FunctionType::get( 938 ir_builder_.getVoidTy(), 939 {int8_ptr_type, ir_ptr_type, ir_ptr_type, ir_ptr_type, int64_type, 940 int64_type, int64_type, int64_type, int64_type, int64_type, 941 int64_type, int64_type, int64_type, int64_type, int64_type, 942 int64_type, int64_type, int64_type, int64_type, int64_type, 943 int64_type, int64_type, int64_type, int64_type}, 944 /*isVarArg=*/false); 945 bool multi_threaded_eigen = 946 hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); 947 const char* fn_name = 948 primitive_type == F16 949 ? (multi_threaded_eigen 950 ? runtime::kEigenConvF16SymbolName 951 : runtime::kEigenSingleThreadedConvF16SymbolName) 952 : (multi_threaded_eigen 953 ? runtime::kEigenConvF32SymbolName 954 : runtime::kEigenSingleThreadedConvF32SymbolName); 955 llvm::Function* conv_func = llvm::cast<llvm::Function>( 956 module_->getOrInsertFunction(fn_name, conv_type)); 957 conv_func->setCallingConv(llvm::CallingConv::C); 958 conv_func->setDoesNotThrow(); 959 conv_func->setOnlyAccessesArgMemory(); 960 ir_builder_.CreateCall( 961 conv_func, { 962 GetExecutableRunOptionsArgument(), 963 ir_builder_.CreateBitCast( 964 GetEmittedValueFor(convolution), ir_ptr_type), 965 ir_builder_.CreateBitCast(lhs_address, ir_ptr_type), 966 ir_builder_.CreateBitCast(rhs_address, ir_ptr_type), 967 ir_builder_.getInt64(input_batch), 968 ir_builder_.getInt64(input_rows), 969 ir_builder_.getInt64(input_cols), 970 ir_builder_.getInt64(input_channels), 971 ir_builder_.getInt64(kernel_rows), 972 ir_builder_.getInt64(kernel_cols), 973 ir_builder_.getInt64(kernel_channels), 974 ir_builder_.getInt64(kernel_filters), 975 ir_builder_.getInt64(output_rows), 976 ir_builder_.getInt64(output_cols), 977 ir_builder_.getInt64(row_stride), 978 ir_builder_.getInt64(col_stride), 979 ir_builder_.getInt64(padding_top), 980 ir_builder_.getInt64(padding_bottom), 981 ir_builder_.getInt64(padding_left), 982 ir_builder_.getInt64(padding_right), 983 ir_builder_.getInt64(lhs_row_dilation), 984 ir_builder_.getInt64(lhs_col_dilation), 985 ir_builder_.getInt64(rhs_row_dilation), 986 ir_builder_.getInt64(rhs_col_dilation), 987 }); 988 989 return Status::OK(); 990 } 991 } 992 993 // This is a completely un-optimized version of convolution just to 994 // have an early version that works. E.g. the input index and 995 // padding calculation is not hoisted out of the inner loop. 996 // 997 // See the description of convolution in the XLA documentation for the pseudo 998 // code for convolution. 999 return EmitTargetElementLoop( 1000 convolution, [this, convolution, lhs, rhs, window, 1001 dnums](const llvm_ir::IrArray::Index& index) { 1002 int num_spatial_dims = dnums.output_spatial_dimensions_size(); 1003 std::vector<llvm::Value*> output_spatial(num_spatial_dims); 1004 for (int i = 0; i < num_spatial_dims; ++i) { 1005 output_spatial[i] = index[dnums.output_spatial_dimensions(i)]; 1006 } 1007 llvm::Value* output_feature = index[dnums.output_feature_dimension()]; 1008 llvm::Value* batch = index[dnums.output_batch_dimension()]; 1009 1010 // We will accumulate the products into this sum to calculate 1011 // the output entry at the given index. 1012 PrimitiveType lhs_element_type = lhs->shape().element_type(); 1013 llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry( 1014 llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_), 1015 "convolution_sum_address", &ir_builder_, 1016 MinimumAlignmentForPrimitiveType(lhs_element_type)); 1017 ir_builder_.CreateStore( 1018 llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0), sum_address); 1019 1020 llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &ir_builder_); 1021 std::vector<llvm::Value*> kernel_spatial(num_spatial_dims); 1022 for (int i = 0; i < num_spatial_dims; ++i) { 1023 kernel_spatial[i] = 1024 loops 1025 .AddLoop(0, 1026 rhs->shape().dimensions( 1027 dnums.kernel_spatial_dimensions(i)), 1028 tensorflow::strings::StrCat("k", i)) 1029 ->GetIndVarValue(); 1030 } 1031 llvm::Value* input_feature = 1032 loops 1033 .AddLoop( 1034 0, lhs->shape().dimensions(dnums.input_feature_dimension()), 1035 "iz") 1036 ->GetIndVarValue(); 1037 1038 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); 1039 1040 // Calculate the spatial index in the input array, taking striding, 1041 // dilation and padding into account. An index in the padding will be 1042 // out of the bounds of the array. 1043 const auto calculate_input_index = 1044 [this](llvm::Value* output_index, llvm::Value* kernel_index, 1045 const WindowDimension& window_dim) { 1046 llvm::Value* strided_index = ir_builder_.CreateNSWMul( 1047 output_index, ir_builder_.getInt64(window_dim.stride())); 1048 llvm::Value* dilated_kernel_index = ir_builder_.CreateNSWMul( 1049 kernel_index, 1050 ir_builder_.getInt64(window_dim.window_dilation())); 1051 return ir_builder_.CreateNSWSub( 1052 ir_builder_.CreateNSWAdd(strided_index, dilated_kernel_index), 1053 ir_builder_.getInt64(window_dim.padding_low())); 1054 }; 1055 std::vector<llvm::Value*> input_spatial(num_spatial_dims); 1056 for (int i = 0; i < num_spatial_dims; ++i) { 1057 input_spatial[i] = calculate_input_index( 1058 output_spatial[i], kernel_spatial[i], window.dimensions(i)); 1059 } 1060 1061 // We need to check if 0 <= input dim < bound, as otherwise we are in 1062 // the padding so that we can skip the computation. That is equivalent 1063 // to input dim < bound as an *unsigned* comparison, since a negative 1064 // value will wrap to a large positive value. The input dim is dilated, 1065 // so we need to dilate the bound as well to match. 1066 1067 // Also need to check that the input coordinates are not in one of the 1068 // holes created by base dilation. 1069 const auto not_in_hole = [&](llvm::Value* input_index, 1070 int64 base_dilation) { 1071 llvm::Value* remainder = ir_builder_.CreateSRem( 1072 input_index, ir_builder_.getInt64(base_dilation)); 1073 return ir_builder_.CreateICmpEQ(remainder, ir_builder_.getInt64(0)); 1074 }; 1075 1076 llvm::Value* in_bounds_condition = ir_builder_.getInt1(true); 1077 for (int i = 0; i < num_spatial_dims; ++i) { 1078 llvm::ConstantInt* input_bound = 1079 ir_builder_.getInt64(window_util::DilatedBound( 1080 lhs->shape().dimensions(dnums.input_spatial_dimensions(i)), 1081 window.dimensions(i).base_dilation())); 1082 llvm::Value* dim_in_bound = 1083 ir_builder_.CreateICmpULT(input_spatial[i], input_bound); 1084 llvm::Value* dim_not_in_hole = not_in_hole( 1085 input_spatial[i], window.dimensions(i).base_dilation()); 1086 llvm::Value* dim_ok = 1087 ir_builder_.CreateAnd(dim_in_bound, dim_not_in_hole); 1088 in_bounds_condition = 1089 ir_builder_.CreateAnd(in_bounds_condition, dim_ok); 1090 } 1091 1092 // Now we need to map the dilated base coordinates back to the actual 1093 // data indices on the lhs. 1094 const auto undilate = [&](llvm::Value* input_index, 1095 int64 base_dilation) { 1096 return ir_builder_.CreateSDiv(input_index, 1097 ir_builder_.getInt64(base_dilation)); 1098 }; 1099 for (int i = 0; i < num_spatial_dims; ++i) { 1100 input_spatial[i] = 1101 undilate(input_spatial[i], window.dimensions(i).base_dilation()); 1102 } 1103 1104 llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( 1105 in_bounds_condition, "in-bounds", &ir_builder_); 1106 SetToFirstInsertPoint(if_data.true_block, &ir_builder_); 1107 1108 // We are not in the padding, so carry out the computation. 1109 int num_dims = num_spatial_dims + 2; 1110 llvm_ir::IrArray::Index input_index(num_dims); 1111 for (int i = 0; i < num_spatial_dims; ++i) { 1112 input_index[dnums.input_spatial_dimensions(i)] = input_spatial[i]; 1113 } 1114 input_index[dnums.input_feature_dimension()] = input_feature; 1115 input_index[dnums.input_batch_dimension()] = batch; 1116 1117 llvm_ir::IrArray kernel_array(GetIrArrayFor(rhs)); 1118 llvm_ir::IrArray::Index kernel_index(num_dims); 1119 for (int i = 0; i < num_spatial_dims; ++i) { 1120 kernel_index[dnums.kernel_spatial_dimensions(i)] = 1121 window.dimensions(i).window_reversal() 1122 ? ir_builder_.CreateNSWSub( 1123 ir_builder_.getInt64(window.dimensions(i).size() - 1), 1124 kernel_spatial[i]) 1125 : kernel_spatial[i]; 1126 } 1127 1128 kernel_index[dnums.kernel_input_feature_dimension()] = input_feature; 1129 kernel_index[dnums.kernel_output_feature_dimension()] = output_feature; 1130 1131 llvm_ir::IrArray input_array(GetIrArrayFor(lhs)); 1132 llvm::Value* product = ir_builder_.CreateFMul( 1133 input_array.EmitReadArrayElement(input_index, &ir_builder_), 1134 kernel_array.EmitReadArrayElement(kernel_index, &ir_builder_)); 1135 llvm::Value* sum = ir_builder_.CreateFAdd( 1136 ir_builder_.CreateLoad(sum_address), product); 1137 ir_builder_.CreateStore(sum, sum_address); 1138 1139 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); 1140 return ir_builder_.CreateLoad(sum_address); 1141 }); 1142 } 1143 1144 Status IrEmitter::HandleFft(HloInstruction* fft) { 1145 auto operand = fft->operand(0); 1146 TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( 1147 /*instruction=*/*fft, /*operands=*/{operand}, 1148 /*supported_types=*/{F32, C64})); 1149 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout())); 1150 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout())); 1151 VLOG(3) << "operand=" << ShapeUtil::HumanStringWithLayout(operand->shape()); 1152 VLOG(3) << "fft=" << ShapeUtil::HumanStringWithLayout(fft->shape()); 1153 1154 llvm::Value* operand_address = GetEmittedValueFor(operand); 1155 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fft)); 1156 1157 const std::vector<int64>& fft_length = fft->fft_length(); 1158 int64 input_batch = 1; 1159 for (int i = 0; i < fft->shape().dimensions_size() - fft_length.size(); i++) { 1160 input_batch *= fft->shape().dimensions(i); 1161 } 1162 1163 // Args have been computed, make the call. 1164 llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo(); 1165 llvm::Type* int32_type = ir_builder_.getInt32Ty(); 1166 llvm::Type* int64_type = ir_builder_.getInt64Ty(); 1167 llvm::FunctionType* fft_type = llvm::FunctionType::get( 1168 ir_builder_.getVoidTy(), 1169 {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type, 1170 int64_type, int64_type, int64_type, int64_type}, 1171 /*isVarArg=*/false); 1172 const char* fn_name = runtime::kEigenFftSymbolName; 1173 llvm::Function* fft_func = llvm::cast<llvm::Function>( 1174 module_->getOrInsertFunction(fn_name, fft_type)); 1175 fft_func->setCallingConv(llvm::CallingConv::C); 1176 fft_func->setDoesNotThrow(); 1177 fft_func->setOnlyAccessesInaccessibleMemOrArgMem(); 1178 const int fft_rank = fft_length.size(); 1179 ir_builder_.CreateCall( 1180 fft_func, 1181 {GetExecutableRunOptionsArgument(), 1182 ir_builder_.CreateBitCast(GetEmittedValueFor(fft), int8_ptr_type), 1183 ir_builder_.CreateBitCast(operand_address, int8_ptr_type), 1184 ir_builder_.getInt32(fft->fft_type()), ir_builder_.getInt32(fft_rank), 1185 ir_builder_.getInt64(input_batch), 1186 ir_builder_.getInt64(fft_rank > 0 ? fft_length[0] : 0), 1187 ir_builder_.getInt64(fft_rank > 1 ? fft_length[1] : 0), 1188 ir_builder_.getInt64(fft_rank > 2 ? fft_length[2] : 0)}); 1189 1190 return Status::OK(); 1191 } 1192 1193 Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { 1194 if (hlo_module_config_.replica_count() == 1) { 1195 // When there is a single replica, a cross replica sum is the identity 1196 // function, and the buffer assignment expects a copy (we could eliminate 1197 // these at the HLO level as an optimization). 1198 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs)); 1199 return EmitMemcpy(*crs->operand(0), *crs); 1200 } 1201 1202 // TODO(b/33011107): Support cross replica sum on CPU. 1203 return Unimplemented("CrossReplicaSum is not implemented on CPU."); 1204 } 1205 1206 // Fills up the free variables in 'index_with_free_var' with values from 1207 // 'filler_index'. The size of free variables must be the same as the 1208 // size of 'filler_index'. 1209 // 1210 // This is often used after dimension reduction, where 1211 // 'index_with_free_var' has one or more dimensions reduced, which serves as 1212 // free variables (represented as nullptr). For example, if we have a 4 1213 // dimensional input and index for the dimension being reduced is 1214 // 2 (third dimension), we will have an index like [i, j, NULL, k] 1215 // after reduced dimension. 1216 // 1217 // Here we fill up that free variable by 'filler_index', which contains 1218 // the value in the reduced dimension. 1219 static llvm_ir::IrArray::Index FillReducedDimensionIndex( 1220 llvm_ir::IrArray::Index index_with_free_var, 1221 llvm_ir::IrArray::Index filler_index) { 1222 llvm_ir::IrArray::Index::const_iterator it = filler_index.begin(); 1223 1224 for (size_t i = 0; i < index_with_free_var.size(); ++i) { 1225 if (index_with_free_var[i] == nullptr) { 1226 index_with_free_var[i] = *it++; 1227 } 1228 } 1229 CHECK(filler_index.end() == it); 1230 return index_with_free_var; 1231 } 1232 1233 Status IrEmitter::HandleParameter(HloInstruction* parameter) { 1234 VLOG(2) << "HandleParameter: " << parameter->ToString(); 1235 auto param_number = parameter->parameter_number(); 1236 auto param_shape = parameter->shape(); 1237 1238 // We have to access the parameter at offset param_number in the params 1239 // array. The code generated here is equivalent to this C code: 1240 // 1241 // i8* param_address_untyped = params[param_number]; 1242 // Param* param_address_typed = (Param*)param_address_untyped; 1243 // 1244 // Where Param is the actual element type of the underlying buffer (for 1245 // example, float for an XLA F32 element type). 1246 llvm::Value* params = compute_function_->parameters_arg(); 1247 llvm::Value* param_address_offset = 1248 llvm_ir::EmitBufferIndexingGEP(params, param_number, &ir_builder_); 1249 llvm::LoadInst* param_address_untyped = 1250 ir_builder_.CreateLoad(param_address_offset); 1251 param_address_untyped->setName(AsStringRef(IrName(parameter, "untyped"))); 1252 if (is_top_level_computation_ && 1253 hlo_module_config_.debug_options() 1254 .xla_llvm_enable_invariant_load_metadata()) { 1255 // In the entry computation the parameter slots in the %params argument are 1256 // invariant through program execution. In computations that are called 1257 // from the entry computation (via kWhile, kCall and kConditional) the 1258 // parameter slots are *not* invariant since they're written to by their 1259 // callers. 1260 param_address_untyped->setMetadata( 1261 llvm::LLVMContext::MD_invariant_load, 1262 llvm::MDNode::get(param_address_untyped->getContext(), /*MDs=*/{})); 1263 } 1264 1265 llvm::Value* param_address_typed = ir_builder_.CreateBitCast( 1266 param_address_untyped, IrShapeType(param_shape)->getPointerTo()); 1267 emitted_value_[parameter] = param_address_typed; 1268 1269 if (!ShapeUtil::IsOpaque(param_shape)) { 1270 AttachAlignmentMetadataForLoad(param_address_untyped, param_shape); 1271 AttachDereferenceableMetadataForLoad(param_address_untyped, param_shape); 1272 } 1273 1274 VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*param_address_typed); 1275 return Status::OK(); 1276 } 1277 1278 // Returns true if the relative order of the unreduced dimensions stays the same 1279 // through the reduce operation. 1280 static bool ReductionPreservesLayout(const HloInstruction& reduce) { 1281 DCHECK_EQ(reduce.opcode(), HloOpcode::kReduce); 1282 1283 // Maps dimensions that were not reduced from their dimension numbers in the 1284 // source shape to their dimensions numbers in the destination shape. 1285 // 1286 // So if we reduce f32[A,B,C,D] on dimensions 1 and 2, this map contains 1287 // [0->0, 3->1]. 1288 gtl::FlatMap<int64, int64> unreduced_dim_map; 1289 1290 gtl::FlatSet<int64> reduced_dims(reduce.dimensions().begin(), 1291 reduce.dimensions().end()); 1292 1293 const Shape& operand_shape = reduce.operand(0)->shape(); 1294 const Shape& result_shape = reduce.shape(); 1295 1296 int64 delta = 0; 1297 for (int64 i = 0; i < operand_shape.dimensions_size(); i++) { 1298 if (reduced_dims.count(i)) { 1299 delta++; 1300 } else { 1301 InsertOrDie(&unreduced_dim_map, i, i - delta); 1302 } 1303 } 1304 1305 // Iterate dimensions minor to major and check that the corresponding 1306 // dimensions in the source and target shapes are equivalent. 1307 int64 result_dim_idx = 0; 1308 for (int64 operand_dim_idx = 0; 1309 operand_dim_idx < operand_shape.dimensions_size(); operand_dim_idx++) { 1310 int64 operand_dim = operand_shape.layout().minor_to_major(operand_dim_idx); 1311 if (!reduced_dims.count(operand_dim)) { 1312 if (FindOrDie(unreduced_dim_map, operand_dim) != 1313 result_shape.layout().minor_to_major(result_dim_idx++)) { 1314 return false; 1315 } 1316 } 1317 } 1318 1319 CHECK_EQ(result_dim_idx, result_shape.dimensions_size()); 1320 1321 return true; 1322 } 1323 1324 IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( 1325 HloComputation* function, string* failure_reason) const { 1326 CHECK_EQ(function->num_parameters(), 2); 1327 1328 auto root_instruction = function->root_instruction(); 1329 CHECK(ShapeUtil::IsScalar(root_instruction->shape())); 1330 1331 if (root_instruction->operand_count() != 2) { 1332 *failure_reason = "root instruction is not a binary operation"; 1333 return nullptr; 1334 } 1335 1336 const Shape& root_shape = root_instruction->shape(); 1337 if (ShapeUtil::ElementIsComplex(root_shape)) { 1338 // TODO(b/65408531): Complex add could by done via bitcast to <float x [2N]> 1339 // Complex multiply would be more challenging. We could perhaps use a 1340 // strided load to get all reals in a vector, all images in a vector, or use 1341 // CreateShuffleVector on a bitcast to float x [2N]. 1342 *failure_reason = "complex values not supported"; 1343 return nullptr; 1344 } 1345 bool root_is_floating_point = ShapeUtil::ElementIsFloating(root_shape); 1346 bool root_is_integral = ShapeUtil::ElementIsIntegral(root_shape); 1347 bool root_is_signed = ShapeUtil::ElementIsSigned(root_shape); 1348 1349 auto lhs = root_instruction->operand(0); 1350 auto rhs = root_instruction->operand(1); 1351 1352 auto param_0 = function->parameter_instruction(0); 1353 auto param_1 = function->parameter_instruction(1); 1354 if (!(lhs == param_0 && rhs == param_1) && 1355 !(rhs == param_0 && lhs == param_1)) { 1356 *failure_reason = 1357 "root instruction is not a binary operation on the incoming arguments"; 1358 return nullptr; 1359 } 1360 1361 CHECK(ShapeUtil::IsScalar(lhs->shape()) && ShapeUtil::IsScalar(rhs->shape())); 1362 1363 // This is visually similar to ElementalIrEmitter, though conceptually we're 1364 // doing something different here. ElementalIrEmitter emits scalar operations 1365 // while these emit scalar or vector operations depending on the type of the 1366 // operands. See CreateShardedVectorType for the actual types in use here. 1367 switch (root_instruction->opcode()) { 1368 default: 1369 *failure_reason = "did not recognize root instruction opcode"; 1370 return nullptr; 1371 1372 case HloOpcode::kAdd: 1373 return [root_is_integral](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, 1374 llvm::Value* rhs) { 1375 return root_is_integral ? ir_builder->CreateAdd(lhs, rhs) 1376 : ir_builder->CreateFAdd(lhs, rhs); 1377 }; 1378 1379 case HloOpcode::kMultiply: 1380 return [root_is_integral](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, 1381 llvm::Value* rhs) { 1382 return root_is_integral ? ir_builder->CreateMul(lhs, rhs) 1383 : ir_builder->CreateFMul(lhs, rhs); 1384 }; 1385 1386 case HloOpcode::kAnd: 1387 return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, 1388 llvm::Value* rhs) { return ir_builder->CreateAnd(lhs, rhs); }; 1389 1390 case HloOpcode::kOr: 1391 return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, 1392 llvm::Value* rhs) { return ir_builder->CreateOr(lhs, rhs); }; 1393 1394 case HloOpcode::kMaximum: 1395 return [root_is_floating_point, root_is_signed]( 1396 llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, 1397 llvm::Value* rhs) { 1398 if (root_is_floating_point) { 1399 return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::maxnum, 1400 {lhs, rhs}, {lhs->getType()}, 1401 ir_builder); 1402 } 1403 1404 return ir_builder->CreateSelect( 1405 ir_builder->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SGE 1406 : llvm::ICmpInst::ICMP_UGE, 1407 lhs, rhs), 1408 lhs, rhs); 1409 }; 1410 1411 case HloOpcode::kMinimum: 1412 return [root_is_floating_point, root_is_signed]( 1413 llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, 1414 llvm::Value* rhs) { 1415 if (root_is_floating_point) { 1416 return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::minnum, 1417 {lhs, rhs}, {lhs->getType()}, 1418 ir_builder); 1419 } 1420 1421 return ir_builder->CreateSelect( 1422 ir_builder->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SLE 1423 : llvm::ICmpInst::ICMP_ULE, 1424 lhs, rhs), 1425 lhs, rhs); 1426 }; 1427 } 1428 } 1429 1430 IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType( 1431 PrimitiveType element_type, unsigned element_count) { 1432 int vector_register_size_in_elements = 1433 target_machine_features_.vector_register_byte_size( 1434 *compute_function_->function()) / 1435 ShapeUtil::ByteSizeOfPrimitiveType(element_type); 1436 1437 ShardedVectorType sharded_vector_type; 1438 llvm::Type* element_ir_type = 1439 llvm_ir::PrimitiveTypeToIrType(element_type, module_); 1440 1441 for (int i = 0, e = 1 + tensorflow::Log2Ceiling(element_count); i < e; i++) { 1442 // For every power of two present in element_count, we generate one or more 1443 // vector or scalar types. 1444 const unsigned current_size_fragment = 1u << i; 1445 if (!(element_count & current_size_fragment)) { 1446 // Power of two not present in element_count. 1447 continue; 1448 } 1449 1450 if (current_size_fragment == 1) { 1451 // Single element, use a scalar type. 1452 sharded_vector_type.push_back(element_ir_type); 1453 continue; 1454 } 1455 1456 // Lower "current_size_fragment" number of elements using (as few as 1457 // possible) vector registers. 1458 1459 if (current_size_fragment >= vector_register_size_in_elements) { 1460 auto vector_type = llvm::VectorType::get( 1461 element_ir_type, vector_register_size_in_elements); 1462 sharded_vector_type.insert( 1463 sharded_vector_type.end(), 1464 current_size_fragment / vector_register_size_in_elements, 1465 vector_type); 1466 1467 // Both current_size_fragment and vector_register_size_in_elements are 1468 // powers of two. 1469 CHECK_EQ(current_size_fragment % vector_register_size_in_elements, 0); 1470 continue; 1471 } 1472 1473 // For now we assume that vector_register_size_in_elements and lower powers 1474 // of two are all legal vector sizes (or at least can be lowered easily by 1475 // LLVM). 1476 sharded_vector_type.push_back( 1477 llvm::VectorType::get(element_ir_type, current_size_fragment)); 1478 } 1479 return sharded_vector_type; 1480 } 1481 1482 StatusOr<IrEmitter::ShardedVector> 1483 IrEmitter::EmitInnerLoopForVectorizedReduction( 1484 const ReductionGenerator& reduction_generator, 1485 const llvm_ir::IrArray::Index& output_index, 1486 const ShardedVectorType& accumulator_type, HloInstruction* init_value, 1487 HloInstruction* arg, gtl::ArraySlice<int64> dimensions, 1488 unsigned element_alignment) { 1489 ShardedVector accumulator; 1490 accumulator.reserve(accumulator_type.size()); 1491 for (auto accumulator_shard_type : accumulator_type) { 1492 accumulator.push_back(llvm_ir::EmitAllocaAtFunctionEntry( 1493 accumulator_shard_type, "accumulator", &ir_builder_, 0)); 1494 } 1495 1496 llvm::Value* init_value_ssa = 1497 ir_builder_.CreateLoad(GetEmittedValueFor(init_value)); 1498 1499 for (llvm::Value* accumulator_shard : accumulator) { 1500 llvm::Value* initial_value; 1501 auto shard_type = accumulator_shard->getType()->getPointerElementType(); 1502 if (auto vector_type = llvm::dyn_cast<llvm::VectorType>(shard_type)) { 1503 initial_value = ir_builder_.CreateVectorSplat( 1504 vector_type->getNumElements(), init_value_ssa); 1505 } else { 1506 initial_value = init_value_ssa; 1507 } 1508 1509 ir_builder_.CreateAlignedStore(initial_value, accumulator_shard, 1510 element_alignment); 1511 } 1512 1513 llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"), 1514 &ir_builder_); 1515 llvm_ir::IrArray::Index reduced_dims_index = 1516 reduction_loop_nest.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, 1517 "reduction_dim"); 1518 1519 SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(), 1520 &ir_builder_); 1521 1522 llvm_ir::IrArray arg_array(GetIrArrayFor(arg)); 1523 llvm_ir::IrArray::Index input_index = reduced_dims_index; 1524 llvm_ir::IrArray::Index::const_iterator it = output_index.begin(); 1525 1526 for (size_t i = 0; i < input_index.size(); ++i) { 1527 if (input_index[i] == nullptr) { 1528 input_index[i] = *it++; 1529 } 1530 } 1531 CHECK(output_index.end() == it); 1532 1533 llvm::Value* input_address = ir_builder_.CreateBitCast( 1534 arg_array.EmitArrayElementAddress(input_index, &ir_builder_), 1535 ir_builder_.getInt8PtrTy()); 1536 1537 for (int i = 0; i < accumulator.size(); i++) { 1538 auto input_address_typed = 1539 ir_builder_.CreateBitCast(input_address, accumulator[i]->getType()); 1540 auto current_accumulator_value = 1541 ir_builder_.CreateAlignedLoad(accumulator[i], element_alignment); 1542 auto addend = 1543 ir_builder_.CreateAlignedLoad(input_address_typed, element_alignment); 1544 arg_array.AnnotateLoadStoreInstructionWithMetadata(addend); 1545 1546 auto reduced_result = 1547 reduction_generator(&ir_builder_, current_accumulator_value, addend); 1548 ir_builder_.CreateAlignedStore(reduced_result, accumulator[i], 1549 element_alignment); 1550 1551 if (i != (accumulator.size() - 1)) { 1552 input_address = ir_builder_.CreateConstInBoundsGEP1_32( 1553 reduced_result->getType(), input_address_typed, 1); 1554 } 1555 } 1556 1557 SetToFirstInsertPoint(reduction_loop_nest.GetOuterLoopExitBasicBlock(), 1558 &ir_builder_); 1559 1560 ShardedVector result_ssa; 1561 result_ssa.reserve(accumulator.size()); 1562 for (auto accumulator_shard : accumulator) { 1563 result_ssa.push_back( 1564 ir_builder_.CreateAlignedLoad(accumulator_shard, element_alignment)); 1565 } 1566 return result_ssa; 1567 } 1568 1569 void IrEmitter::EmitShardedVectorStore( 1570 llvm::Value* store_address, const std::vector<llvm::Value*>& value_to_store, 1571 const int alignment, const llvm_ir::IrArray& containing_array) { 1572 for (int i = 0; i < value_to_store.size(); i++) { 1573 auto store_address_typed = ir_builder_.CreateBitCast( 1574 store_address, 1575 llvm::PointerType::getUnqual(value_to_store[i]->getType())); 1576 1577 auto store_instruction = ir_builder_.CreateAlignedStore( 1578 value_to_store[i], store_address_typed, alignment); 1579 containing_array.AnnotateLoadStoreInstructionWithMetadata( 1580 store_instruction); 1581 1582 if (i != (value_to_store.size() - 1)) { 1583 store_address = ir_builder_.CreateConstInBoundsGEP1_32( 1584 value_to_store[i]->getType(), store_address_typed, 1); 1585 } 1586 } 1587 } 1588 1589 StatusOr<bool> IrEmitter::EmitVectorizedReduce( 1590 HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, 1591 gtl::ArraySlice<int64> dimensions, HloComputation* function, 1592 string* failure_reason) { 1593 if (!ReductionPreservesLayout(*reduce)) { 1594 return false; 1595 } 1596 1597 ReductionGenerator reduction_generator = 1598 MatchReductionGenerator(function, failure_reason); 1599 if (!reduction_generator) { 1600 return false; 1601 } 1602 1603 int vectorization_factor_in_bytes = 1604 target_machine_features_.vectorization_factor_in_bytes(); 1605 1606 // We try to process vectorization_factor elements at the same time. 1607 const int vectorization_factor = 1608 vectorization_factor_in_bytes / 1609 ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()); 1610 1611 bool is_reduction_over_minor_dimension = 1612 std::find(dimensions.begin(), dimensions.end(), 1613 LayoutUtil::Minor(arg->shape().layout(), 0)) != 1614 dimensions.end(); 1615 1616 unsigned element_alignment = tensorflow::MathUtil::GCD<unsigned>( 1617 ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()), 1618 MinimumAlignmentForPrimitiveType(reduce->shape().element_type())); 1619 1620 if (is_reduction_over_minor_dimension) { 1621 // TODO(sanjoy): Implement vectorized reduction over the minor dimension. 1622 *failure_reason = "reduction over minor dimension not implemented"; 1623 return false; 1624 } 1625 1626 CHECK(!ShapeUtil::IsTuple(reduce->shape())); 1627 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(reduce)); 1628 1629 // We know we're not reducing over the most minor dimension, which means we 1630 // can lower the reduction loop as: 1631 // 1632 // 1. We're reducing over dimensions R0, R1. 1633 // 2. D0 is the most minor dimension. 1634 // 3. VS is the vectorization stride (we want to reduce this many elements at 1635 // once) 1636 // 1637 // for (d1 in D1) { 1638 // for (d0 in D0 with stride VS) { 1639 // vector_acc = init 1640 // for (r1 in R1) { 1641 // for (r0 in R0) { 1642 // vector_acc = elementwise_reduce(vector_acc, input[d1, d0, r1, r0] 1643 // } 1644 // } 1645 // output[d1, d0] = vector_acc 1646 // } 1647 // } 1648 1649 llvm_ir::ForLoopNest loop_nest(IrName(reduce), &ir_builder_); 1650 llvm_ir::IrArray::Index array_index(reduce->shape().dimensions_size()); 1651 for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0; 1652 --i) { 1653 int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i); 1654 int64 start_index = 0; 1655 int64 end_index = reduce->shape().dimensions(dimension); 1656 std::unique_ptr<llvm_ir::ForLoop> loop = 1657 loop_nest.AddLoop(start_index, end_index, 1658 tensorflow::strings::Printf("dim.%lld", dimension)); 1659 array_index[dimension] = loop->GetIndVarValue(); 1660 } 1661 1662 int64 innermost_dimension = LayoutUtil::Minor(reduce->shape().layout(), 0); 1663 int64 innermost_dimension_size = 1664 reduce->shape().dimensions(innermost_dimension); 1665 1666 if (llvm::BasicBlock* innermost_body_bb = 1667 loop_nest.GetInnerLoopBodyBasicBlock()) { 1668 SetToFirstInsertPoint(innermost_body_bb, &ir_builder_); 1669 } 1670 1671 auto outermost_loop_exit_block = loop_nest.GetOuterLoopExitBasicBlock(); 1672 1673 if (innermost_dimension_size >= vectorization_factor) { 1674 int64 start_index = 0; 1675 int64 end_index = (innermost_dimension_size / vectorization_factor) * 1676 vectorization_factor; 1677 std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop( 1678 start_index, end_index, vectorization_factor, 1679 tensorflow::strings::Printf("dim.%lld", innermost_dimension)); 1680 array_index[innermost_dimension] = loop->GetIndVarValue(); 1681 1682 SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &ir_builder_); 1683 1684 ShardedVectorType vector_type = CreateShardedVectorType( 1685 reduce->shape().element_type(), vectorization_factor); 1686 TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accumulator, 1687 EmitInnerLoopForVectorizedReduction( 1688 reduction_generator, array_index, vector_type, 1689 init_value, arg, dimensions, element_alignment)); 1690 1691 llvm_ir::IrArray target_array = GetIrArrayFor(reduce); 1692 llvm::Value* output_address = 1693 target_array.EmitArrayElementAddress(array_index, &ir_builder_); 1694 EmitShardedVectorStore(output_address, accumulator, element_alignment, 1695 target_array); 1696 1697 if (auto exit_terminator = loop->GetExitBasicBlock()->getTerminator()) { 1698 CHECK_GT(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1); 1699 ir_builder_.SetInsertPoint(exit_terminator); 1700 } else { 1701 CHECK_EQ(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1); 1702 ir_builder_.SetInsertPoint(loop->GetExitBasicBlock()); 1703 } 1704 } 1705 1706 // Since we increment the stride for the inner dimension by more than 1, we 1707 // may need to peel out an "epilogue" iteration to get the remaining elements 1708 // in the following case: 1709 if (innermost_dimension_size % vectorization_factor) { 1710 // TODO(b/63775531): Consider using a scalar loop here to save on code size. 1711 array_index[innermost_dimension] = 1712 ir_builder_.getInt64(innermost_dimension_size - 1713 (innermost_dimension_size % vectorization_factor)); 1714 1715 ShardedVectorType vector_type = CreateShardedVectorType( 1716 reduce->shape().element_type(), 1717 innermost_dimension_size % vectorization_factor); 1718 TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accumulator, 1719 EmitInnerLoopForVectorizedReduction( 1720 reduction_generator, array_index, vector_type, 1721 init_value, arg, dimensions, element_alignment)); 1722 1723 llvm_ir::IrArray target_array = GetIrArrayFor(reduce); 1724 llvm::Value* output_address = 1725 target_array.EmitArrayElementAddress(array_index, &ir_builder_); 1726 EmitShardedVectorStore(output_address, accumulator, element_alignment, 1727 target_array); 1728 } 1729 1730 if (outermost_loop_exit_block) { 1731 ir_builder_.SetInsertPoint(outermost_loop_exit_block); 1732 } 1733 1734 return true; 1735 } 1736 1737 Status IrEmitter::HandleReduce(HloInstruction* reduce) { 1738 auto arg = reduce->mutable_operand(0); 1739 auto init_value = reduce->mutable_operand(1); 1740 gtl::ArraySlice<int64> dimensions(reduce->dimensions()); 1741 HloComputation* function = reduce->to_apply(); 1742 if (!options::VectorizedReduceDisabled(hlo_module_config_)) { 1743 string vectorization_failure_reason; 1744 TF_ASSIGN_OR_RETURN( 1745 bool vectorization_successful, 1746 EmitVectorizedReduce(reduce, arg, init_value, dimensions, function, 1747 &vectorization_failure_reason)); 1748 if (vectorization_successful) { 1749 VLOG(1) << "Successfully vectorized reduction " << reduce->ToString() 1750 << "\n"; 1751 return Status::OK(); 1752 } else { 1753 VLOG(1) << "Could not vectorize reduction " << reduce->ToString() << ": " 1754 << vectorization_failure_reason; 1755 } 1756 } 1757 1758 // The called computation should have been emitted previously. 1759 llvm::Function* reducer_function = FindOrDie(emitted_functions_, function); 1760 return EmitTargetElementLoop( 1761 reduce, [this, reduce, arg, init_value, dimensions, 1762 reducer_function](const llvm_ir::IrArray::Index& index) { 1763 // Initialize an accumulator with init_value. 1764 PrimitiveType accumulator_type = reduce->shape().element_type(); 1765 llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry( 1766 llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), 1767 "accumulator", &ir_builder_, 1768 MinimumAlignmentForPrimitiveType(accumulator_type)); 1769 llvm::Value* init_value_addr = GetEmittedValueFor(init_value); 1770 llvm::Value* load_init_value = ir_builder_.CreateLoad(init_value_addr); 1771 ir_builder_.CreateStore(load_init_value, accumulator_addr); 1772 1773 // The enclosing loops go over all the target elements. Now we have to 1774 // compute the actual target element. For this, we build a new loop nest 1775 // to iterate over all the reduction dimensions in the argument. 1776 // AddLoopsForShapeOnDimensions will return an Index where induction 1777 // Value*s are placed for each dimension in dimensions, and all the rest 1778 // are nullptrs. 1779 llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &ir_builder_); 1780 const llvm_ir::IrArray::Index reduced_dims_index = 1781 loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, 1782 "reduction_dim"); 1783 1784 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); 1785 1786 // Build a full index for the input argument, using reduced_dims_index 1787 // as the base. In reduced_dims_index only the reduction dimensions are 1788 // filled in. We fill in the rest of the dimensions with induction 1789 // Value*s taken from 'index' which iterates over the target array. 1790 // See the high-level description in the XLA documentation for details. 1791 llvm_ir::IrArray arg_array(GetIrArrayFor(arg)); 1792 llvm_ir::IrArray::Index input_index = reduced_dims_index; 1793 llvm_ir::IrArray::Index::const_iterator it = index.begin(); 1794 1795 for (size_t i = 0; i < input_index.size(); ++i) { 1796 if (input_index[i] == nullptr) { 1797 input_index[i] = *it++; 1798 } 1799 } 1800 CHECK(index.end() == it); 1801 1802 // Apply the reduction function to the loaded value. 1803 llvm::Value* input_address = 1804 arg_array.EmitArrayElementAddress(input_index, &ir_builder_); 1805 llvm::Value* result = EmitElementFunctionCall( 1806 reducer_function, reduce->shape(), 1807 {accumulator_addr, input_address}, "reduce_function"); 1808 ir_builder_.CreateStore(result, accumulator_addr); 1809 1810 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); 1811 return ir_builder_.CreateLoad(accumulator_addr); 1812 }); 1813 } 1814 1815 Status IrEmitter::HandleSend(HloInstruction* send) { 1816 // TODO(b/33942983): Support Send/Recv on CPU. 1817 return Unimplemented("Send is not implemented on CPU."); 1818 } 1819 1820 Status IrEmitter::HandleSendDone(HloInstruction* send_done) { 1821 // TODO(b/33942983): Support Send/Recv on CPU. 1822 return Unimplemented("Send-done is not implemented on CPU."); 1823 } 1824 1825 Status IrEmitter::HandleSlice(HloInstruction* slice) { 1826 VLOG(2) << "HandleSlice: " << slice->ToString(); 1827 auto operand = slice->operand(0); 1828 // The code below emits a sequential loop nest. For the parallel backend, use 1829 // ParallelLoopEmitter which respects dynamic loop bounds. 1830 if (ShouldEmitParallelLoopFor(*slice)) { 1831 return DefaultAction(slice); 1832 } 1833 1834 // The code below assumes the layouts are equal. 1835 if (!LayoutUtil::Equal(operand->shape().layout(), slice->shape().layout())) { 1836 return DefaultAction(slice); 1837 } 1838 1839 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice)); 1840 1841 if (ShapeUtil::HasZeroElements(slice->shape())) { 1842 return Status::OK(); 1843 } 1844 1845 const Layout& layout = operand->shape().layout(); 1846 const int64 num_dims = operand->shape().dimensions_size(); 1847 1848 // The slice lowering finds maximal contiguous blocks of memory that can be 1849 // copied from the source to the target. This is done by looking at the 1850 // source/target layout in minor to major order and do the following: 1851 // 1852 // * Find an initial segment of dimensions along which the slice uses the 1853 // whole dimension. These are the "inner" dimensions and can be folded into 1854 // the memcpy. 1855 // 1856 // * Of the remaining dimensions decide which ones require loops. 1857 // 1858 // * Implement the memcpy within the innermost loop. 1859 1860 gtl::FlatSet<int64> inner_dims; 1861 for (int64 dim : LayoutUtil::MinorToMajor(layout)) { 1862 if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) { 1863 break; 1864 } 1865 inner_dims.insert(dim); 1866 } 1867 1868 const bool is_trivial_copy = (inner_dims.size() == num_dims); 1869 if (is_trivial_copy) { 1870 if (ShapeUtil::IsEffectiveScalar(slice->shape())) { 1871 return DefaultAction(slice); 1872 } else { 1873 return EmitMemcpy(*slice, *operand); 1874 } 1875 } 1876 1877 // The memcpy will copy elements that are logically this shape (allowed to be 1878 // scalar). 1879 const Shape logical_element_shape = ShapeUtil::FilterDimensions( 1880 [&inner_dims](int64 dim) -> bool { return inner_dims.count(dim); }, 1881 operand->shape()); 1882 1883 const int64 primitive_elements_per_logical_element = 1884 ShapeUtil::ElementsIn(logical_element_shape); 1885 1886 // memcpy_dim is the innermost (in terms of layout) dimension for which the 1887 // slice does *not* just copy all the elements along the dimension. 1888 const int64 memcpy_dim = LayoutUtil::Minor(layout, inner_dims.size()); 1889 1890 const bool memcpy_is_contiguous = slice->slice_strides(memcpy_dim) == 1; 1891 // The number of logical elements that can be copied in a single call 1892 // to memcpy. We can only copy 1 element at a time if there is a non-trivial 1893 // stride. 1894 const int64 memcpy_logical_elements = 1895 memcpy_is_contiguous 1896 ? slice->slice_limits(memcpy_dim) - slice->slice_starts(memcpy_dim) 1897 : 1; 1898 1899 // Determine the dimensions that get lowered as loops. 1900 std::vector<int64> outer_dims; 1901 for (int64 i = 0; i < num_dims - inner_dims.size() - 1; ++i) { 1902 outer_dims.push_back(LayoutUtil::Major(layout, i)); 1903 } 1904 1905 // Is the slice along the memcpy dimension contiguous? If not, then memcpy_dim 1906 // needs to be wrapped around a loop as well. 1907 if (!memcpy_is_contiguous) { 1908 outer_dims.push_back(memcpy_dim); 1909 } 1910 1911 llvm_ir::IrArray target_array = GetIrArrayFor(slice); 1912 1913 const int64 num_outer_loops = outer_dims.size(); 1914 llvm_ir::ForLoopNest loops(IrName(slice), &ir_builder_); 1915 llvm_ir::IrArray::Index target_index = 1916 loops.AddLoopsForShapeOnDimensions(slice->shape(), outer_dims, "slice"); 1917 1918 // Only the indices for the outer dimensions have been initialized in 1919 // target_index. The rest of the indices should get initialized to 0, since 1920 // for the rest of the dimensions the copy writes to the full dimension. 1921 std::replace(target_index.begin(), target_index.end(), 1922 static_cast<llvm::Value*>(nullptr), 1923 static_cast<llvm::Value*>(ir_builder_.getInt64(0))); 1924 1925 if (num_outer_loops > 0) { 1926 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); 1927 } 1928 1929 llvm_ir::IrArray source_array = GetIrArrayFor(operand); 1930 const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice( 1931 /*shape=*/slice->shape(), /*starts=*/slice->slice_starts(), 1932 /*strides=*/slice->slice_strides(), /*builder=*/&ir_builder_); 1933 1934 llvm::Value* memcpy_dest = target_array.EmitArrayElementAddress( 1935 target_index, &ir_builder_, "slice.dest"); 1936 llvm::Value* memcpy_source = source_array.EmitArrayElementAddress( 1937 source_index, &ir_builder_, "slice.source"); 1938 1939 const int64 memcpy_elements = 1940 primitive_elements_per_logical_element * memcpy_logical_elements; 1941 1942 EmitTransferElements(memcpy_dest, memcpy_source, memcpy_elements, 1943 slice->shape().element_type(), target_array, 1944 source_array); 1945 1946 if (VLOG_IS_ON(2)) { 1947 const int64 memcpy_bytes = 1948 ShapeUtil::ByteSizeOf(logical_element_shape) * memcpy_elements; 1949 VLOG(2) << " emitted copy of " << memcpy_bytes << " bytes inside " 1950 << num_outer_loops << " loops"; 1951 } 1952 1953 if (num_outer_loops > 0) { 1954 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); 1955 } 1956 1957 return Status::OK(); 1958 } 1959 1960 Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice) { 1961 if (ShapeUtil::IsScalar(dynamic_slice->shape())) { 1962 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_slice)); 1963 return EmitMemcpy(*dynamic_slice->operand(0), *dynamic_slice); 1964 } 1965 return DefaultAction(dynamic_slice); 1966 } 1967 1968 Status IrEmitter::HandleDynamicUpdateSlice( 1969 HloInstruction* dynamic_update_slice) { 1970 auto update = dynamic_update_slice->operand(1); 1971 if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) { 1972 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice)); 1973 return EmitMemcpy(*update, *dynamic_update_slice); 1974 } else if (llvm_ir::CanUpdateDynamicSliceInPlace(dynamic_update_slice, 1975 assignment_)) { 1976 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice)); 1977 auto operands = GetIrArraysForOperandsOf(dynamic_update_slice); 1978 return llvm_ir::EmitDynamicUpdateSliceInPlace( 1979 operands, GetIrArrayFor(dynamic_update_slice), 1980 IrName(dynamic_update_slice, "in_place"), &ir_builder_); 1981 } 1982 return DefaultAction(dynamic_update_slice); 1983 } 1984 1985 Status IrEmitter::HandleRecv(HloInstruction* recv) { 1986 // TODO(b/33942983): Support Send/Recv on CPU. 1987 return Unimplemented("Recv is not implemented on CPU."); 1988 } 1989 1990 Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) { 1991 // TODO(b/33942983): Support Send/Recv on CPU. 1992 return Unimplemented("Recv-done is not implemented on CPU."); 1993 } 1994 1995 Status IrEmitter::HandlePad(HloInstruction* pad) { 1996 // CPU backend does not properly handle negative padding but this is ok 1997 // because negative padding should be removed by the algebraic simplifier. 1998 for (auto& padding_dimension : pad->padding_config().dimensions()) { 1999 if (padding_dimension.edge_padding_low() < 0 || 2000 padding_dimension.edge_padding_high() < 0) { 2001 return InternalErrorStrCat( 2002 "Encountered negative padding in IrEmitter on CPU. " 2003 "This should have been eliminated at the HLO level. ", 2004 pad->ToString()); 2005 } 2006 } 2007 2008 // First, fill in the padding value to all output elements. 2009 TF_RETURN_IF_ERROR(EmitTargetElementLoop( 2010 pad, "initialize", 2011 [this, pad](const llvm_ir::IrArray::Index& target_index) { 2012 const HloInstruction* padding_value = pad->operand(1); 2013 llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value); 2014 return ir_builder_.CreateLoad(padding_value_addr); 2015 })); 2016 2017 // Create a loop to iterate over the operand elements and update the output 2018 // locations where the operand elements should be stored. 2019 llvm_ir::ForLoopNest loops(IrName(pad, "assign"), &ir_builder_); 2020 const HloInstruction* operand = pad->operand(0); 2021 const llvm_ir::IrArray::Index operand_index = 2022 loops.AddLoopsForShape(operand->shape(), "operand"); 2023 2024 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); 2025 2026 // Load an element from the operand. 2027 llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); 2028 llvm::Value* operand_data = 2029 operand_array.EmitReadArrayElement(operand_index, &ir_builder_); 2030 2031 // Compute the output index the operand element should be assigned to. 2032 // output_index := edge_padding_low + operand_index * (interior_padding + 1) 2033 const PaddingConfig& padding_config = pad->padding_config(); 2034 llvm_ir::IrArray::Index output_index; 2035 for (size_t i = 0; i < operand_index.size(); ++i) { 2036 llvm::Value* offset = ir_builder_.CreateMul( 2037 operand_index[i], 2038 ir_builder_.getInt64(padding_config.dimensions(i).interior_padding() + 2039 1)); 2040 llvm::Value* index = ir_builder_.CreateAdd( 2041 offset, 2042 ir_builder_.getInt64(padding_config.dimensions(i).edge_padding_low())); 2043 output_index.push_back(index); 2044 } 2045 2046 // Store the operand element to the computed output location. 2047 llvm_ir::IrArray output_array(GetIrArrayFor(pad)); 2048 output_array.EmitWriteArrayElement(output_index, operand_data, &ir_builder_); 2049 2050 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); 2051 return Status::OK(); 2052 } 2053 2054 // If `hlo` is a Transpose, returns its operand; otherwise returns `hlo` itself. 2055 static const HloInstruction* StripTranspose(const HloInstruction& hlo) { 2056 if (hlo.IsRank2Transpose()) { 2057 return hlo.operand(0); 2058 } 2059 return &hlo; 2060 } 2061 2062 Status IrEmitter::HandleFusion(HloInstruction* fusion) { 2063 auto* root = fusion->fused_expression_root(); 2064 if (fusion->fusion_kind() == HloInstruction::FusionKind::kTransposeDot) { 2065 DCHECK(root->opcode() == HloOpcode::kDot); 2066 const HloInstruction* lhs_parameter = StripTranspose(*root->operand(0)); 2067 const HloInstruction* rhs_parameter = StripTranspose(*root->operand(1)); 2068 DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && 2069 rhs_parameter->opcode() == HloOpcode::kParameter); 2070 const HloInstruction* lhs = 2071 fusion->operand(lhs_parameter->parameter_number()); 2072 const HloInstruction* rhs = 2073 fusion->operand(rhs_parameter->parameter_number()); 2074 2075 TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( 2076 /*instruction=*/*root, /*operands=*/{lhs, rhs}, 2077 /*supported_types=*/{F32})); 2078 2079 llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); 2080 llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); 2081 2082 Shape target_shape = fusion->shape(); 2083 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); 2084 llvm_ir::IrArray target_array = GetIrArrayFor(fusion); 2085 VLOG(2) << "HandleFusion kTransposeDot: "; 2086 VLOG(2) << " lhs operand: " 2087 << llvm_ir::DumpToString(*lhs_array.GetBasePointer()); 2088 VLOG(2) << " rhs operand: " 2089 << llvm_ir::DumpToString(*rhs_array.GetBasePointer()); 2090 VLOG(2) << " target: " 2091 << llvm_ir::DumpToString(*target_array.GetBasePointer()); 2092 2093 // Dot operation is complicated so we delegate to a helper class. 2094 TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( 2095 *root, root->operand(0)->IsRank2Transpose(), 2096 root->operand(1)->IsRank2Transpose(), target_array, lhs_array, 2097 rhs_array, /*addend_array=*/nullptr, GetExecutableRunOptionsArgument(), 2098 &ir_builder_, hlo_module_config_, target_machine_features_)); 2099 return Status::OK(); 2100 } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, 2101 assignment_)) { 2102 VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace"; 2103 CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); 2104 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); 2105 2106 // Delegate to common implementation of fused in-place dynamic-update-slice. 2107 auto operands = GetIrArraysForOperandsOf(fusion); 2108 return llvm_ir::EmitFusedDynamicUpdateSliceInPlace( 2109 fusion, operands, GetIrArrayFor(fusion), &elemental_emitter, 2110 &ir_builder_); 2111 } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) { 2112 VLOG(3) << "HandleFusion kLoop"; 2113 CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); 2114 auto operands = GetIrArraysForOperandsOf(fusion); 2115 FusedIrEmitter fused_emitter(operands, &elemental_emitter); 2116 TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); 2117 2118 return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator()); 2119 } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kOutput) { 2120 VLOG(3) << "HandleFusion kOutput"; 2121 int64 dot_op_index = root->operand(0)->opcode() == HloOpcode::kDot ? 0 : 1; 2122 const HloInstruction* dot = root->operand(dot_op_index); 2123 CHECK_EQ(dot->opcode(), HloOpcode::kDot) 2124 << dot->ToString() << " " 2125 << fusion->fused_instructions_computation()->ToString(); 2126 2127 int64 dot_lhs_param_number = dot->operand(0)->parameter_number(); 2128 int64 dot_rhs_param_number = dot->operand(1)->parameter_number(); 2129 int64 addend_param_number = 2130 root->operand(1 - dot_op_index)->parameter_number(); 2131 2132 Shape target_shape = fusion->shape(); 2133 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); 2134 llvm_ir::IrArray target_array = GetIrArrayFor(fusion); 2135 2136 llvm_ir::IrArray lhs_array( 2137 GetIrArrayFor(fusion->operand(dot_lhs_param_number))); 2138 llvm_ir::IrArray rhs_array( 2139 GetIrArrayFor(fusion->operand(dot_rhs_param_number))); 2140 llvm_ir::IrArray addend_array( 2141 GetIrArrayFor(fusion->operand(addend_param_number))); 2142 2143 TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( 2144 *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, 2145 lhs_array, rhs_array, &addend_array, GetExecutableRunOptionsArgument(), 2146 &ir_builder_, hlo_module_config_, target_machine_features_)); 2147 return Status::OK(); 2148 } else { 2149 return Unimplemented("Fusion kind not implemented on CPU"); 2150 } 2151 } 2152 2153 Status IrEmitter::HandleCall(HloInstruction* call) { 2154 HloComputation* computation = call->to_apply(); 2155 llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation); 2156 2157 std::vector<llvm::Value*> parameter_addresses; 2158 for (const HloInstruction* operand : call->operands()) { 2159 parameter_addresses.push_back(GetEmittedValueFor(operand)); 2160 } 2161 2162 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call)); 2163 2164 if (!computation->root_instruction()->outer_dimension_partitions().empty() && 2165 !parallel_cpu_backend_) { 2166 // ParallelTaskAssignment assigned partitions, emit call to 2167 // ParallelForkJoin. 2168 std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments( 2169 parameter_addresses, &ir_builder_, computation->name(), 2170 /*return_value_buffer=*/emitted_value_[call], 2171 /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), 2172 /*temp_buffers_arg=*/GetTempBuffersArgument(), 2173 /*profile_counters_arg=*/GetProfileCountersArgument()); 2174 2175 HloInstruction* root = computation->root_instruction(); 2176 TF_RETURN_IF_ERROR(EmitCallToParallelForkJoin( 2177 call_args, root->shape(), root->outer_dimension_partitions(), 2178 &ir_builder_, call_ir_function, computation->name())); 2179 } else { 2180 EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, 2181 emitted_value_[call], computation->name()); 2182 } 2183 2184 return Status::OK(); 2185 } 2186 2187 Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { 2188 gtl::ArraySlice<HloInstruction*> operands(custom_call->operands()); 2189 tensorflow::StringPiece custom_call_target(custom_call->custom_call_target()); 2190 llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy(); 2191 llvm::AllocaInst* operands_alloca = 2192 llvm_ir::EmitAllocaAtFunctionEntryWithCount( 2193 i8_ptr_type, ir_builder_.getInt32(operands.size()), 2194 "cc_operands_alloca", &ir_builder_); 2195 for (size_t i = 0; i < operands.size(); ++i) { 2196 const HloInstruction* operand = operands[i]; 2197 llvm::Value* operand_as_i8ptr = 2198 ir_builder_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type); 2199 llvm::Value* slot_in_operands_alloca = ir_builder_.CreateInBoundsGEP( 2200 operands_alloca, {ir_builder_.getInt64(i)}); 2201 ir_builder_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca); 2202 } 2203 auto* custom_call_ir_function = 2204 llvm::cast<llvm::Function>(module_->getOrInsertFunction( 2205 AsStringRef(custom_call_target), 2206 llvm::FunctionType::get( 2207 /*Result=*/ir_builder_.getVoidTy(), 2208 /*Params=*/{i8_ptr_type, operands_alloca->getType()}, 2209 /*isVarArg=*/false))); 2210 2211 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); 2212 auto* output_address_arg = ir_builder_.CreatePointerCast( 2213 GetEmittedValueFor(custom_call), i8_ptr_type); 2214 2215 ir_builder_.CreateCall(custom_call_ir_function, 2216 {output_address_arg, operands_alloca}); 2217 2218 return Status::OK(); 2219 } 2220 2221 Status IrEmitter::HandleWhile(HloInstruction* xla_while) { 2222 // Precondition: Condition computation must return a scalar bool. 2223 HloComputation* condition = xla_while->while_condition(); 2224 TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) && 2225 condition->root_instruction()->shape().element_type() == PRED) 2226 << "While condition computation must return bool; got: " 2227 << ShapeUtil::HumanString(condition->root_instruction()->shape()); 2228 // Check that all while-related buffers share an allocation slice. 2229 TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( 2230 xla_while->shape(), 2231 [this, &xla_while](const Shape& /*subshape*/, 2232 const ShapeIndex& index) -> Status { 2233 auto check = [this](const HloInstruction* a, const HloInstruction* b, 2234 const ShapeIndex& index) { 2235 const BufferAllocation::Slice slice_a = 2236 assignment_.GetUniqueSlice(a, index).ConsumeValueOrDie(); 2237 const BufferAllocation::Slice slice_b = 2238 assignment_.GetUniqueSlice(b, index).ConsumeValueOrDie(); 2239 if (slice_a != slice_b) { 2240 return InternalError( 2241 "instruction %s %s does not share slice with " 2242 "instruction %s %s", 2243 a->ToString().c_str(), slice_a.ToString().c_str(), 2244 b->ToString().c_str(), slice_b.ToString().c_str()); 2245 } 2246 return Status::OK(); 2247 }; 2248 TF_RETURN_IF_ERROR(check(xla_while, xla_while->operand(0), index)); 2249 TF_RETURN_IF_ERROR(check( 2250 xla_while, xla_while->while_condition()->parameter_instruction(0), 2251 index)); 2252 TF_RETURN_IF_ERROR( 2253 check(xla_while, xla_while->while_body()->parameter_instruction(0), 2254 index)); 2255 TF_RETURN_IF_ERROR(check( 2256 xla_while, xla_while->while_body()->root_instruction(), index)); 2257 return Status::OK(); 2258 })); 2259 2260 // Set emitted value to that of 'init' with which it shares an allocation. 2261 const HloInstruction* init = xla_while->operand(0); 2262 emitted_value_[xla_while] = GetEmittedValueFor(init); 2263 2264 // The called computation should have been emitted previously. 2265 llvm::Function* condition_ir_function = 2266 FindOrDie(emitted_functions_, condition); 2267 llvm::Function* body_ir_function = 2268 FindOrDie(emitted_functions_, xla_while->while_body()); 2269 2270 // Generating: 2271 // while (Condition(while_result)) { 2272 // // CopyInsertion pass inserts copies which enable 'while_result' to 2273 // // be passed back in as 'Body' parameter. 2274 // while_result = Body(while_result); // Insert 2275 // } 2276 2277 // Terminates the current block with a branch to a while header. 2278 llvm::BasicBlock* header_bb = llvm::BasicBlock::Create( 2279 module_->getContext(), AsStringRef(IrName(xla_while, "header")), 2280 compute_function_->function()); 2281 ir_builder_.CreateBr(header_bb); 2282 ir_builder_.SetInsertPoint(header_bb); 2283 2284 // Calls the condition function to determine whether to proceed with the 2285 // body. It must return a bool, so use the scalar call form. 2286 llvm::Value* while_result = GetEmittedValueFor(xla_while); 2287 llvm::Value* while_condition = EmitElementFunctionCall( 2288 condition_ir_function, condition->root_instruction()->shape(), 2289 {while_result}, IrName(xla_while, "cond")); 2290 llvm::Value* while_predicate = ir_builder_.CreateICmpNE( 2291 while_condition, 2292 llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0)); 2293 2294 // Branches to the body or to the while exit depending on the condition. 2295 llvm::BasicBlock* body_bb = llvm::BasicBlock::Create( 2296 module_->getContext(), AsStringRef(IrName(xla_while, "body")), 2297 compute_function_->function()); 2298 llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create( 2299 module_->getContext(), AsStringRef(IrName(xla_while, "exit"))); 2300 ir_builder_.CreateCondBr(while_predicate, body_bb, exit_bb); 2301 2302 // Calls the body function from the body block. 2303 ir_builder_.SetInsertPoint(body_bb); 2304 2305 // Calls the body function. 2306 EmitArrayFunctionCallInto(body_ir_function, {while_result}, while_result, 2307 IrName(xla_while, "body")); 2308 // Finishes with a branch back to the header. 2309 ir_builder_.CreateBr(header_bb); 2310 2311 // Adds the exit block to the function and sets the insert point there. 2312 compute_function_->function()->getBasicBlockList().push_back(exit_bb); 2313 ir_builder_.SetInsertPoint(exit_bb); 2314 2315 return Status::OK(); 2316 } 2317 2318 StatusOr<bool> IrEmitter::EmitFastConcatenate( 2319 HloInstruction* concatenate, gtl::ArraySlice<HloInstruction*> operands, 2320 string* failure_reason) { 2321 if (ShouldEmitParallelLoopFor(*concatenate)) { 2322 *failure_reason = 2323 "cannot generate memcpy-based concat for the parallel CPU backend"; 2324 return false; 2325 } 2326 2327 const Shape& output_shape = concatenate->shape(); 2328 for (auto* op : operands) { 2329 if (!LayoutUtil::Equal(op->shape().layout(), output_shape.layout())) { 2330 *failure_reason = "operand has mismatching layouts"; 2331 return false; 2332 } 2333 if (LayoutUtil::IsPadded(op->shape())) { 2334 *failure_reason = "operand has padded layout"; 2335 return false; 2336 } 2337 } 2338 2339 CHECK(!LayoutUtil::IsPadded(concatenate->shape())); 2340 2341 // We split the dimensions into three categories: the dimension over which we 2342 // are concatenating (concat_dim), the dimensions that are minor to it 2343 // (inner_dims) and the dimensions that are major to it (outer_dims). 2344 2345 int64 concat_dim = concatenate->dimensions(0); 2346 const Layout& output_layout = output_shape.layout(); 2347 auto output_min2maj = LayoutUtil::MinorToMajor(output_layout); 2348 auto concat_dim_layout_itr = 2349 std::find(output_min2maj.begin(), output_min2maj.end(), concat_dim); 2350 2351 std::vector<int64> inner_dims(output_min2maj.begin(), concat_dim_layout_itr); 2352 std::vector<int64> outer_dims(std::next(concat_dim_layout_itr), 2353 output_min2maj.end()); 2354 2355 llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy(); 2356 llvm::Type* i8_type = ir_builder_.getInt8Ty(); 2357 2358 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate)); 2359 llvm_ir::IrArray target_array = GetIrArrayFor(concatenate); 2360 2361 llvm_ir::ForLoopNest loops(IrName(concatenate), &ir_builder_); 2362 llvm_ir::IrArray::Index outer_dims_index = 2363 loops.AddLoopsForShapeOnDimensions(output_shape, outer_dims, "concat"); 2364 std::replace(outer_dims_index.begin(), outer_dims_index.end(), 2365 static_cast<llvm::Value*>(nullptr), 2366 static_cast<llvm::Value*>(ir_builder_.getInt64(0))); 2367 2368 if (!outer_dims.empty()) { 2369 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); 2370 } 2371 2372 PrimitiveType primitive_type = output_shape.element_type(); 2373 unsigned primitive_type_size = 2374 ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); 2375 2376 // Contiguous subregions from each operand to the concatenate contribute to a 2377 // contiguous subregion in the target buffer starting at target_region_begin. 2378 llvm::Value* target_region_begin = ir_builder_.CreateBitCast( 2379 target_array.EmitArrayElementAddress(outer_dims_index, &ir_builder_, 2380 "target_region"), 2381 i8_ptr_type); 2382 int64 byte_offset_into_target_region = 0; 2383 2384 int64 inner_dims_product = 2385 std::accumulate(inner_dims.begin(), inner_dims.end(), 1l, 2386 [&](int64 product, int64 inner_dim) { 2387 return product * output_shape.dimensions(inner_dim); 2388 }); 2389 2390 // For each operand, emit a memcpy from the operand to the target of size 2391 // equal to the product of inner dimensions. 2392 for (HloInstruction* operand : operands) { 2393 const Shape& input_shape = operand->shape(); 2394 llvm_ir::IrArray source_array = GetIrArrayFor(operand); 2395 llvm::Value* copy_source_address = ir_builder_.CreateBitCast( 2396 source_array.EmitArrayElementAddress(outer_dims_index, &ir_builder_, 2397 "src_addr"), 2398 i8_ptr_type); 2399 2400 llvm::Value* copy_target_address = ir_builder_.CreateGEP( 2401 i8_type, target_region_begin, 2402 ir_builder_.getInt64(byte_offset_into_target_region)); 2403 2404 EmitTransferElements( 2405 copy_target_address, copy_source_address, 2406 inner_dims_product * input_shape.dimensions(concat_dim), primitive_type, 2407 target_array, source_array); 2408 2409 byte_offset_into_target_region += inner_dims_product * 2410 input_shape.dimensions(concat_dim) * 2411 primitive_type_size; 2412 } 2413 2414 if (!outer_dims.empty()) { 2415 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); 2416 } 2417 2418 return true; 2419 } 2420 2421 void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, 2422 int64 element_count, 2423 PrimitiveType primitive_type, 2424 const llvm_ir::IrArray& target_array, 2425 const llvm_ir::IrArray& source_array) { 2426 unsigned primitive_type_size = 2427 ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); 2428 unsigned element_alignment = tensorflow::MathUtil::GCD<unsigned>( 2429 primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type)); 2430 llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual( 2431 llvm_ir::PrimitiveTypeToIrType(primitive_type, module_)); 2432 2433 if (element_count == 1) { 2434 auto* load_instruction = ir_builder_.CreateAlignedLoad( 2435 ir_builder_.CreateBitCast(source, primitive_ptr_type), 2436 element_alignment); 2437 source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction); 2438 auto* store_instruction = ir_builder_.CreateAlignedStore( 2439 load_instruction, ir_builder_.CreateBitCast(target, primitive_ptr_type), 2440 element_alignment); 2441 target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction); 2442 } else { 2443 auto* memcpy_instruction = ir_builder_.CreateMemCpy( 2444 target, source, element_count * primitive_type_size, element_alignment); 2445 2446 // The memcpy does the load and the store internally. The aliasing related 2447 // metadata has to reflect that. 2448 std::map<int, llvm::MDNode*> merged_metadata = 2449 llvm_ir::MergeMetadata(&module_->getContext(), source_array.metadata(), 2450 target_array.metadata()); 2451 for (const auto& kind_md_pair : merged_metadata) { 2452 memcpy_instruction->setMetadata(kind_md_pair.first, kind_md_pair.second); 2453 } 2454 } 2455 } 2456 2457 Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { 2458 gtl::ArraySlice<HloInstruction*> operands(concatenate->operands()); 2459 string failure_reason; 2460 TF_ASSIGN_OR_RETURN( 2461 bool successful, 2462 EmitFastConcatenate(concatenate, operands, &failure_reason)); 2463 if (successful) { 2464 VLOG(1) << "Emitted fast concatenate for " << concatenate->ToString(); 2465 return Status::OK(); 2466 } 2467 2468 VLOG(1) << "Could not emit fast concatenate for " << concatenate->ToString() 2469 << ": " << failure_reason; 2470 2471 return DefaultAction(concatenate); 2472 } 2473 2474 Status IrEmitter::HandleConditional(HloInstruction* conditional) { 2475 auto pred = conditional->operand(0); 2476 auto true_arg = conditional->operand(1); 2477 auto false_arg = conditional->operand(2); 2478 TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) && 2479 pred->shape().element_type() == PRED) 2480 << "Predicate on a Conditional must be bool; got: " 2481 << ShapeUtil::HumanString(pred->shape()); 2482 2483 HloComputation* true_computation = conditional->true_computation(); 2484 HloComputation* false_computation = conditional->false_computation(); 2485 TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(), 2486 true_computation->root_instruction()->shape())) 2487 << "Shape of conditional should be same as the shape of the true " 2488 << "computation; got: " << ShapeUtil::HumanString(conditional->shape()) 2489 << " and " 2490 << ShapeUtil::HumanString(true_computation->root_instruction()->shape()); 2491 2492 TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(), 2493 false_computation->root_instruction()->shape())) 2494 << "Shape of conditional should be same as the shape of the false " 2495 << "computation; got: " << ShapeUtil::HumanString(conditional->shape()) 2496 << " and " 2497 << ShapeUtil::HumanString(false_computation->root_instruction()->shape()); 2498 2499 llvm::Function* true_function = 2500 FindOrDie(emitted_functions_, true_computation); 2501 llvm::Function* false_function = 2502 FindOrDie(emitted_functions_, false_computation); 2503 2504 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional)); 2505 llvm::Value* conditional_result = GetEmittedValueFor(conditional); 2506 2507 // Generating: 2508 // if (pred) 2509 // cond_result = true_computation(true_operand) 2510 // else 2511 // cond_result = false_computation(false_operand) 2512 llvm::LoadInst* pred_value = ir_builder_.CreateLoad( 2513 GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value"); 2514 llvm::Value* pred_cond = ir_builder_.CreateICmpNE( 2515 pred_value, 2516 llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), 2517 "boolean_predicate"); 2518 llvm_ir::LlvmIfData if_data = 2519 llvm_ir::EmitIfThenElse(pred_cond, "conditional", &ir_builder_); 2520 2521 SetToFirstInsertPoint(if_data.true_block, &ir_builder_); 2522 EmitArrayFunctionCallInto(true_function, {GetEmittedValueFor(true_arg)}, 2523 conditional_result, IrName(conditional, "_true")); 2524 2525 SetToFirstInsertPoint(if_data.false_block, &ir_builder_); 2526 EmitArrayFunctionCallInto(false_function, {GetEmittedValueFor(false_arg)}, 2527 conditional_result, IrName(conditional, "_false")); 2528 2529 SetToFirstInsertPoint(if_data.after_block, &ir_builder_); 2530 return Status::OK(); 2531 } 2532 2533 Status IrEmitter::FinishVisit(HloInstruction* root) { 2534 // When this method is called, we should have already emitted an IR value for 2535 // the root (return) op. The IR value holds the address of the buffer holding 2536 // the value. If the root is a constant or parameter, we perform a memcpy from 2537 // this buffer to the retval buffer of the computation. Otherwise, there's 2538 // nothing to do since the result was already written directly into the output 2539 // buffer. 2540 VLOG(2) << "FinishVisit root: " << root->ToString(); 2541 llvm::Value* root_value = GetEmittedValueFor(root); 2542 VLOG(2) << " value: " << llvm_ir::DumpToString(*root_value); 2543 2544 auto record_complete_computation = [&](llvm::Value* prof_counter) { 2545 if (prof_counter) { 2546 profiling_state_.RecordCompleteComputation(&ir_builder_, prof_counter); 2547 } 2548 }; 2549 2550 // For the parallel cpu backend, we record the total for each embedded 2551 // computation callee with its caller kCall HLO. 2552 if (parallel_cpu_backend_ && is_top_level_computation_) { 2553 auto* computation = root->parent(); 2554 auto* entry_computation = computation->parent()->entry_computation(); 2555 if (computation != entry_computation) { 2556 for (HloInstruction* instruction : entry_computation->instructions()) { 2557 if (instruction->opcode() == HloOpcode::kCall && 2558 instruction->to_apply()->root_instruction() == root) { 2559 record_complete_computation(GetProfileCounterFor(*instruction)); 2560 return Status::OK(); 2561 } 2562 } 2563 } 2564 } 2565 2566 // For the entry computation this increment is cumulative of embedded 2567 // computations since it includes cycles spent in computations invoked by 2568 // While, Call etc. 2569 record_complete_computation(GetProfileCounterFor(*root->parent())); 2570 return Status::OK(); 2571 } 2572 2573 template <typename T> 2574 llvm::Value* IrEmitter::GetProfileCounterCommon( 2575 const T& hlo, 2576 const std::unordered_map<const T*, int64>& profile_index_map) { 2577 auto it = profile_index_map.find(&hlo); 2578 if (it == profile_index_map.end()) { 2579 return nullptr; 2580 } 2581 2582 int64 prof_counter_idx = it->second; 2583 string counter_name = IrName("prof_counter", hlo.name()); 2584 return ir_builder_.CreateGEP(GetProfileCountersArgument(), 2585 ir_builder_.getInt64(prof_counter_idx), 2586 AsStringRef(counter_name)); 2587 } 2588 2589 void IrEmitter::ProfilingState::UpdateProfileCounter( 2590 llvm::IRBuilder<>* ir_builder, llvm::Value* prof_counter, 2591 llvm::Value* cycle_end, llvm::Value* cycle_start) { 2592 auto* cycle_diff = ir_builder->CreateSub(cycle_end, cycle_start); 2593 llvm::LoadInst* old_cycle_count = 2594 ir_builder->CreateLoad(prof_counter, "old_cycle_count"); 2595 auto* new_cycle_count = 2596 ir_builder->CreateAdd(cycle_diff, old_cycle_count, "new_cycle_count"); 2597 ir_builder->CreateStore(new_cycle_count, prof_counter); 2598 } 2599 2600 llvm::Value* IrEmitter::ProfilingState::ReadCycleCounter( 2601 llvm::IRBuilder<>* ir_builder) { 2602 llvm::Module* module = ir_builder->GetInsertBlock()->getModule(); 2603 if (use_rdtscp_) { 2604 llvm::Function* func_llvm_readcyclecounter = 2605 llvm::Intrinsic::getDeclaration(module, 2606 llvm::Intrinsic::readcyclecounter); 2607 return ir_builder->CreateCall(func_llvm_readcyclecounter); 2608 } 2609 llvm::Function* func_llvm_x86_rdtscp = 2610 llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::x86_rdtscp); 2611 if (!aux_i8ptr_) { 2612 llvm::AllocaInst* rdtscp_aux = llvm_ir::EmitAllocaAtFunctionEntry( 2613 ir_builder->getInt32Ty(), "rdtscp_aux", ir_builder); 2614 aux_i8ptr_ = 2615 ir_builder->CreateBitCast(rdtscp_aux, ir_builder->getInt8PtrTy()); 2616 } 2617 llvm::ConstantInt* alloca_size = ir_builder->getInt64(4); 2618 llvm::Function* func_llvm_lifetime_start = 2619 llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::lifetime_start); 2620 ir_builder->CreateCall(func_llvm_lifetime_start, {alloca_size, aux_i8ptr_}); 2621 llvm::Value* rdtscp_call = 2622 ir_builder->CreateCall(func_llvm_x86_rdtscp, aux_i8ptr_); 2623 llvm::Function* func_llvm_lifetime_end = 2624 llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::lifetime_end); 2625 ir_builder->CreateCall(func_llvm_lifetime_end, {alloca_size, aux_i8ptr_}); 2626 return rdtscp_call; 2627 } 2628 2629 void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* ir_builder, 2630 HloInstruction* hlo) { 2631 auto* cycle_start = ReadCycleCounter(ir_builder); 2632 cycle_start->setName(AsStringRef(IrName(hlo, "cycle_start"))); 2633 cycle_starts_[hlo] = cycle_start; 2634 if (first_read_cycle_start_ == nullptr) { 2635 first_read_cycle_start_ = cycle_start; 2636 } 2637 } 2638 2639 void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* ir_builder, 2640 HloInstruction* hlo, 2641 llvm::Value* prof_counter) { 2642 auto* cycle_end = ReadCycleCounter(ir_builder); 2643 cycle_end->setName(AsStringRef(IrName(hlo, "cycle_end"))); 2644 auto* cycle_start = cycle_starts_[hlo]; 2645 UpdateProfileCounter(ir_builder, prof_counter, cycle_end, cycle_start); 2646 last_read_cycle_end_ = cycle_end; 2647 } 2648 2649 void IrEmitter::ProfilingState::RecordCompleteComputation( 2650 llvm::IRBuilder<>* ir_builder, llvm::Value* prof_counter) { 2651 if (last_read_cycle_end_ && first_read_cycle_start_) { 2652 UpdateProfileCounter(ir_builder, prof_counter, last_read_cycle_end_, 2653 first_read_cycle_start_); 2654 } 2655 } 2656 2657 Status IrEmitter::Preprocess(HloInstruction* hlo) { 2658 VLOG(3) << "Visiting: " << hlo->ToString(); 2659 if (instruction_to_profile_idx_.count(hlo)) { 2660 profiling_state_.RecordCycleStart(&ir_builder_, hlo); 2661 } 2662 return Status::OK(); 2663 } 2664 2665 Status IrEmitter::Postprocess(HloInstruction* hlo) { 2666 if (auto* prof_counter = GetProfileCounterFor(*hlo)) { 2667 profiling_state_.RecordCycleDelta(&ir_builder_, hlo, prof_counter); 2668 } 2669 return Status::OK(); 2670 } 2671 2672 llvm_ir::IrArray IrEmitter::GetIrArrayFor(const HloInstruction* hlo) { 2673 llvm::Value* value_for_op = GetEmittedValueFor(hlo); 2674 2675 llvm_ir::IrArray array(value_for_op, hlo->shape()); 2676 AddAliasingInformationToIrArray(*hlo, &array); 2677 return array; 2678 } 2679 2680 std::vector<llvm_ir::IrArray> IrEmitter::GetIrArraysForOperandsOf( 2681 const HloInstruction* hlo) { 2682 std::vector<llvm_ir::IrArray> arrays; 2683 std::transform( 2684 hlo->operands().begin(), hlo->operands().end(), 2685 std::back_inserter(arrays), 2686 [&](const HloInstruction* operand) { return GetIrArrayFor(operand); }); 2687 return arrays; 2688 } 2689 2690 llvm::Value* IrEmitter::GetEmittedValueFor(const HloInstruction* hlo) { 2691 auto it = emitted_value_.find(hlo); 2692 if (it == emitted_value_.end()) { 2693 LOG(FATAL) << "could not find emitted value for: " << hlo->ToString(); 2694 } 2695 return it->second; 2696 } 2697 2698 llvm::Type* IrEmitter::IrShapeType(const Shape& shape) { 2699 return llvm_ir::ShapeToIrType(shape, module_); 2700 } 2701 2702 llvm::Value* IrEmitter::GetProfileCountersArgument() { 2703 return compute_function_->profile_counters_arg(); 2704 } 2705 2706 llvm::Value* IrEmitter::GetTempBuffersArgument() { 2707 return compute_function_->temp_buffers_arg(); 2708 } 2709 2710 llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { 2711 return compute_function_->exec_run_options_arg(); 2712 } 2713 2714 llvm::Value* IrEmitter::EmitTempBufferPointer( 2715 const BufferAllocation::Slice& slice, const Shape& target_shape) { 2716 llvm::Type* element_type = IrShapeType(target_shape); 2717 // The alignment and number of bytes within the temporary buffer is determined 2718 // by the maximal shape as determined by buffer assignment. 2719 const BufferAllocation& allocation = assignment_.GetAllocation(slice.index()); 2720 if (allocation.is_thread_local()) { 2721 // Thread-local allocations should only be assigned a single buffer. 2722 const auto& assigned_buffers = allocation.assigned_buffers(); 2723 CHECK_EQ(1, assigned_buffers.size()); 2724 const Shape& shape = assigned_buffers.begin()->first->shape(); 2725 2726 llvm::AllocaInst*& tempbuf_address = thread_local_buffers_[{ 2727 ir_builder_.GetInsertBlock()->getParent(), slice}]; 2728 if (tempbuf_address == nullptr) { 2729 tempbuf_address = llvm_ir::EmitAllocaAtFunctionEntry( 2730 IrShapeType(shape), 2731 tensorflow::strings::StrCat("thread_local", slice.ToString()), 2732 &ir_builder_, MinimumAlignmentForShape(target_shape)); 2733 } 2734 return ir_builder_.CreateBitCast(tempbuf_address, 2735 element_type->getPointerTo()); 2736 } 2737 2738 llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP( 2739 GetTempBuffersArgument(), slice.index(), &ir_builder_); 2740 llvm::LoadInst* tempbuf_address_base = 2741 ir_builder_.CreateLoad(tempbuf_address_ptr); 2742 if (is_top_level_computation_ && 2743 hlo_module_config_.debug_options() 2744 .xla_llvm_enable_invariant_load_metadata()) { 2745 // In the entry computation the parameter slots in the %params argument are 2746 // invariant through program execution. In computations that are called 2747 // from the entry computation (via kWhile, kCall and kConditional) the 2748 // parameter slots are *not* invariant since they're written to by their 2749 // callers. 2750 tempbuf_address_base->setMetadata( 2751 llvm::LLVMContext::MD_invariant_load, 2752 llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{})); 2753 } 2754 AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size()); 2755 AttachDereferenceableMetadataForLoad(tempbuf_address_base, allocation.size()); 2756 2757 llvm::Value* tempbuf_address_untyped = tempbuf_address_base; 2758 if (slice.offset() > 0) { 2759 // Adjust the address to account for the slice offset. 2760 tempbuf_address_untyped = ir_builder_.CreateInBoundsGEP( 2761 tempbuf_address_base, ir_builder_.getInt64(slice.offset())); 2762 } 2763 return ir_builder_.CreateBitCast(tempbuf_address_untyped, 2764 element_type->getPointerTo()); 2765 } 2766 2767 // Emits a function call returning a single array element. Allocates space 2768 // for a single element_type value, and loads it after call. 2769 llvm::Value* IrEmitter::EmitElementFunctionCall( 2770 llvm::Function* function, const Shape& return_shape, 2771 gtl::ArraySlice<llvm::Value*> parameter_addresses, 2772 tensorflow::StringPiece name) { 2773 llvm::Value* return_value_buffer = EmitArrayFunctionCall( 2774 function, return_shape, 1, parameter_addresses, name); 2775 return ir_builder_.CreateLoad( 2776 return_value_buffer, 2777 AsStringRef(tensorflow::strings::StrCat(name, "_return_value"))); 2778 } 2779 2780 // Emits a core function call based on the following pseudo-code. 2781 // 2782 // char** parameter_addresses_buffer = 2783 // allocate buffer with a pointer for each parameter to the function 2784 // for each parameter index, i.e. for i = 0, ..., #parameters: 2785 // parameter_addresses_buffer[i] = parameter_addresses[i] 2786 // call function(return_value_buffer, 2787 // parameter_addresses_buffer, 2788 // temps) 2789 // return return_value_buffer -- address of the return value. 2790 void IrEmitter::EmitArrayFunctionCallInto( 2791 llvm::Function* function, gtl::ArraySlice<llvm::Value*> parameter_addresses, 2792 llvm::Value* return_value_buffer, tensorflow::StringPiece name) { 2793 ir_builder_.CreateCall( 2794 function, GetArrayFunctionCallArguments( 2795 parameter_addresses, &ir_builder_, name, 2796 /*return_value_buffer=*/return_value_buffer, 2797 /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), 2798 /*temp_buffers_arg=*/GetTempBuffersArgument(), 2799 /*profile_counters_arg=*/GetProfileCountersArgument())); 2800 } 2801 2802 llvm::Value* IrEmitter::EmitArrayFunctionCall( 2803 llvm::Function* function, const Shape& return_shape, int64 element_count, 2804 gtl::ArraySlice<llvm::Value*> parameter_addresses, 2805 tensorflow::StringPiece name) { 2806 llvm::Value* elements = 2807 llvm::ConstantInt::get(ir_builder_.getInt64Ty(), element_count); 2808 PrimitiveType return_type = return_shape.element_type(); 2809 llvm::Value* return_value_buffer = 2810 llvm_ir::EmitAllocaAtFunctionEntryWithCount( 2811 llvm_ir::PrimitiveTypeToIrType(return_type, module_), elements, 2812 tensorflow::strings::StrCat(name, "_return_value_address"), 2813 &ir_builder_, MinimumAlignmentForPrimitiveType(return_type)); 2814 EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer, 2815 name); 2816 return return_value_buffer; 2817 } 2818 2819 Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { 2820 llvm::Value* addr; 2821 const Shape& target_shape = op->shape(); 2822 if (op == op->parent()->root_instruction()) { 2823 // For the root node, we write directly to the output buffer of the 2824 // function. 2825 llvm::Argument* retval = compute_function_->result_arg(); 2826 if (!ShapeUtil::IsNil(target_shape)) { 2827 llvm::AttrBuilder attr_builder; 2828 attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); 2829 attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); 2830 retval->addAttrs(attr_builder); 2831 } 2832 addr = ir_builder_.CreateBitCast(retval, 2833 IrShapeType(target_shape)->getPointerTo()); 2834 } else { 2835 // For other nodes, we need the temporary buffer allocated for this node to 2836 // write the result into. 2837 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, 2838 assignment_.GetUniqueTopLevelSlice(op)); 2839 addr = EmitTempBufferPointer(slice, target_shape); 2840 } 2841 addr->setName(AsStringRef(IrName(op))); 2842 emitted_value_[op] = addr; 2843 return Status::OK(); 2844 } 2845 2846 Status IrEmitter::EmitTargetElementLoop( 2847 HloInstruction* target_op, 2848 const llvm_ir::ElementGenerator& element_generator) { 2849 return EmitTargetElementLoop(target_op, /*desc=*/"", element_generator); 2850 } 2851 2852 Status IrEmitter::EmitTargetElementLoop( 2853 HloInstruction* target_op, tensorflow::StringPiece desc, 2854 const llvm_ir::ElementGenerator& element_generator) { 2855 VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString(); 2856 2857 const Shape& target_shape = target_op->shape(); 2858 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op)); 2859 llvm_ir::IrArray target_array = GetIrArrayFor(target_op); 2860 2861 if (target_op->IsMultiOutputFusion()) { 2862 // For multiple outputs fusion, we need to emit each operand and the root. 2863 TF_RET_CHECK(num_dynamic_loop_bounds_ == 0); 2864 std::vector<llvm_ir::IrArray> output_arrays; 2865 for (int64 i = 0; i < ShapeUtil::TupleElementCount(target_shape); ++i) { 2866 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, 2867 assignment_.GetUniqueSlice(target_op, {i})); 2868 const Shape& element_shape = ShapeUtil::GetSubshape(target_shape, {i}); 2869 llvm::Value* op_target_address = 2870 EmitTempBufferPointer(slice, element_shape); 2871 output_arrays.push_back( 2872 llvm_ir::IrArray(op_target_address, element_shape)); 2873 } 2874 TF_RETURN_IF_ERROR( 2875 llvm_ir::LoopEmitter(element_generator, output_arrays, &ir_builder_) 2876 .EmitLoop(IrName(target_op))); 2877 2878 std::vector<llvm::Value*> tuple_operand_ptrs; 2879 for (int64 i = 0; i < output_arrays.size(); ++i) { 2880 tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); 2881 } 2882 llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &ir_builder_, module_); 2883 2884 } else { 2885 if (ShouldEmitParallelLoopFor(*target_op)) { 2886 // Emit code to read dynamic loop bounds from compute function argument. 2887 std::vector<std::pair<llvm::Value*, llvm::Value*>> dynamic_loop_bounds = 2888 compute_function_->GetDynamicLoopBounds(); 2889 // Emit parallel loop with dynamic loop bounds for most-major dimensions. 2890 TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array, 2891 &dynamic_loop_bounds, &ir_builder_) 2892 .EmitLoop(IrName(target_op))); 2893 } else { 2894 TF_RETURN_IF_ERROR( 2895 llvm_ir::LoopEmitter(element_generator, target_array, &ir_builder_) 2896 .EmitLoop(IrName(target_op))); 2897 } 2898 } 2899 return Status::OK(); 2900 } 2901 2902 Status IrEmitter::EmitMemcpy(const HloInstruction& source, 2903 const HloInstruction& destination) { 2904 llvm::Value* source_value = GetEmittedValueFor(&source); 2905 llvm::Value* destination_value = GetEmittedValueFor(&destination); 2906 int64 source_size = ByteSizeOf(source.shape()); 2907 // TODO(b/63762267): Be more aggressive about specifying alignment. 2908 ir_builder_.CreateMemCpy(destination_value, source_value, source_size, 1); 2909 return Status::OK(); 2910 } 2911 2912 Status IrEmitter::ElementTypesSameAndSupported( 2913 const HloInstruction& instruction, 2914 gtl::ArraySlice<const HloInstruction*> operands, 2915 gtl::ArraySlice<PrimitiveType> supported_types) { 2916 for (auto operand : operands) { 2917 TF_RET_CHECK( 2918 ShapeUtil::SameElementType(operands[0]->shape(), operand->shape())); 2919 } 2920 2921 TF_RET_CHECK(!operands.empty()); 2922 PrimitiveType primitive_type = operands[0]->shape().element_type(); 2923 if (std::find(supported_types.begin(), supported_types.end(), 2924 primitive_type) == supported_types.end()) { 2925 return Unimplemented("unsupported operand type %s in op %s", 2926 PrimitiveType_Name(primitive_type).c_str(), 2927 HloOpcodeString(instruction.opcode()).c_str()); 2928 } 2929 return Status::OK(); 2930 } 2931 2932 Status IrEmitter::DefaultAction(HloInstruction* hlo) { 2933 ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; 2934 for (const HloInstruction* operand : hlo->operands()) { 2935 operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { 2936 return GetIrArrayFor(operand).EmitReadArrayElement(index, &ir_builder_); 2937 }; 2938 } 2939 CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); 2940 return EmitTargetElementLoop( 2941 hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator)); 2942 } 2943 2944 StatusOr<llvm::Value*> IrEmitter::EmitScalarCall( 2945 PrimitiveType return_type, HloComputation* computation, 2946 const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name) { 2947 llvm::Function* llvm_function = FindOrDie(emitted_functions_, computation); 2948 std::vector<llvm::Value*> argument_addrs; 2949 for (auto argument : arguments) { 2950 llvm::Value* argument_addr = llvm_ir::EmitAllocaAtFunctionEntry( 2951 argument->getType(), "arg_addr", &ir_builder_); 2952 ir_builder_.CreateStore(argument, argument_addr); 2953 argument_addrs.push_back(argument_addr); 2954 } 2955 return EmitElementFunctionCall(llvm_function, 2956 ShapeUtil::MakeShape(return_type, {}), 2957 argument_addrs, name); 2958 } 2959 } // namespace cpu 2960 } // namespace xla 2961