/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"

#include <algorithm>
#include <memory>
#include <vector>

#include "absl/base/casts.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Operator.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace llvm_ir {

namespace {

// Returns the module that owns the IR builder's current insertion point.
//
// Note, this function is only useful in an insertion context; in a global
// (e.g. constants) context it will CHECK fail, because the builder then has
// no insert block (and hence no enclosing function/module) to walk up from.
llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* b) {
  auto block = CHECK_NOTNULL(b->GetInsertBlock());
  auto fn = CHECK_NOTNULL(block->getParent());
  auto module = CHECK_NOTNULL(fn->getParent());
  return module;
}

}  // namespace

// Clones `module` and strips every global variable of its initializer,
// demoting each global to an external declaration. Used to produce a
// smaller IR dump when embedded constants would dominate the output.
std::unique_ptr<llvm::Module> DropConstantInitializers(
    const llvm::Module& module) {
  std::unique_ptr<llvm::Module> cloned_module = CloneModule(module);
  for (llvm::GlobalVariable& global_var : cloned_module->globals()) {
    global_var.setInitializer(nullptr);
    // Without an initializer a global must not keep internal/private linkage.
    global_var.setLinkage(llvm::GlobalValue::LinkageTypes::ExternalLinkage);
  }
  return cloned_module;
}

// Prints the textual IR of `module` into a string and returns it.
string DumpModuleToString(const llvm::Module& module) {
  std::string buffer_string;
  llvm::raw_string_ostream ostream(buffer_string);
  module.print(ostream, nullptr);
  ostream.flush();
  return buffer_string;
}

// Emits a call to the LLVM intrinsic `intrinsic_id`, materializing its
// declaration (specialized on `overloaded_types`) in the builder's module
// if it is not already present.
llvm::CallInst* EmitCallToIntrinsic(
    llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
    absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b) {
  llvm::Module* module = ModuleFromIRBuilder(b);
  llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
      module, intrinsic_id, AsArrayRef(overloaded_types));
  return b->CreateCall(intrinsic, AsArrayRef(operands));
}

// Emits IR computing max(lhs, rhs). When fast-math says NaNs cannot occur, a
// single unordered compare + select suffices; otherwise the extra
// `lhs != lhs` (unordered-not-equal) check makes the result propagate lhs
// when lhs is NaN.
llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
                          llvm::IRBuilder<>* b) {
  if (b->getFastMathFlags().noNaNs()) {
    auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value);
    return b->CreateSelect(cmp, lhs_value, rhs_value);
  } else {
    auto cmp_ge = b->CreateFCmpOGE(lhs_value, rhs_value);
    auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
    auto sel_lhs = b->CreateOr(cmp_ge, lhs_is_nan);
    return b->CreateSelect(sel_lhs, lhs_value, rhs_value);
  }
}

// Emits IR computing min(lhs, rhs); mirror image of EmitFloatMax, with the
// same NaN-propagation behavior (lhs is selected when lhs is NaN).
llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
                          llvm::IRBuilder<>* b) {
  if (b->getFastMathFlags().noNaNs()) {
    auto cmp = b->CreateFCmpULE(lhs_value, rhs_value);
    return b->CreateSelect(cmp, lhs_value, rhs_value);
  } else {
    auto cmp_le = b->CreateFCmpOLE(lhs_value, rhs_value);
    auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
    auto sel_lhs = b->CreateOr(cmp_le, lhs_is_nan);
    return b->CreateSelect(sel_lhs, lhs_value, rhs_value);
  }
}

// Emits an inbounds GEP of `array` at `index`. A GlobalVariable pointer
// refers to the aggregate itself, so it needs a leading 0 index to step
// "into" the object before indexing its elements.
llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index,
                                   llvm::IRBuilder<>* b) {
  llvm::Type* array_type = array->getType();
  CHECK(array_type->isPointerTy());
  llvm::PointerType* array_type_as_pointer =
      llvm::cast<llvm::PointerType>(array_type);
  VLOG(2) << "EmitBufferIndexingGEP with type="
          << llvm_ir::DumpToString(*array_type)
          << " array=" << llvm_ir::DumpToString(*array)
          << " index=" << llvm_ir::DumpToString(*index);

  return b->CreateInBoundsGEP(
      array_type_as_pointer->getElementType(), array,
      llvm::isa<llvm::GlobalVariable>(array)
          ? llvm::ArrayRef<llvm::Value*>({b->getInt64(0), index})
          : index);
}

// Convenience overload taking a constant index.
llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index,
                                   llvm::IRBuilder<>* b) {
  return EmitBufferIndexingGEP(array, b->getInt64(index), b);
}

// Maps an XLA primitive type to the LLVM type used for its storage.
// Integer widths are shared between signed/unsigned variants (LLVM integers
// are signless); PRED is stored in a full i8.
llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
                                  llvm::Module* module) {
  switch (element_type) {
    case PRED:
    case S8:
    case U8:
      return llvm::Type::getInt8Ty(module->getContext());
    case S16:
    case U16:
    case BF16:
      // For BF16 we just need some type that is 16 bits wide so that it will
      // take up the right amount of space in memory. LLVM does not have a BF16
      // type (the LLVM half type is IEEE 16 bit floating point, not bfloat), so
      // we can't map it directly to an LLVM type. We will not map a BF16
      // addition to an addition on this type (int16) - this is just the type
      // used for storage.
      return llvm::Type::getInt16Ty(module->getContext());
    case F16:
      return llvm::Type::getHalfTy(module->getContext());
    case S32:
    case U32:
      return llvm::Type::getInt32Ty(module->getContext());
    case S64:
    case U64:
      return llvm::Type::getInt64Ty(module->getContext());
    case F32:
      return llvm::Type::getFloatTy(module->getContext());
    case F64:
      return llvm::Type::getDoubleTy(module->getContext());
    case C64: {
      // Reuse the named struct if a previous call already created it in this
      // module; creating it twice would produce distinct ("complex64.0") types.
      auto cplx_t = module->getTypeByName("complex64");
      if (cplx_t == nullptr) {
        // C++ standard dictates the memory layout of std::complex is contiguous
        // real followed by imaginary. C++11 section 26.4 [complex.numbers]:
        // If z is an lvalue expression of type cv std::complex<T> then the
        // expression reinterpret_cast<cv T(&)[2]>(z) shall be well-formed,
        // reinterpret_cast<cv T(&)[2]>(z)[0] shall designate the real part of
        // z, and reinterpret_cast<cv T(&)[2]>(z)[1] shall designate the
        // imaginary part of z.
        return llvm::StructType::create(
            {llvm::Type::getFloatTy(module->getContext()),
             llvm::Type::getFloatTy(module->getContext())},
            "complex64", /*isPacked=*/true);
      }
      return cplx_t;
    }
    case C128: {
      auto cplx_t = module->getTypeByName("complex128");
      if (cplx_t == nullptr) {
        return llvm::StructType::create(
            {llvm::Type::getDoubleTy(module->getContext()),
             llvm::Type::getDoubleTy(module->getContext())},
            "complex128", /*isPacked=*/true);
      }
      return cplx_t;
    }
    // A Tuple contains an array of pointers. Use i8*.
    case TUPLE:
    // An Opaque is like a void*, use i8*.
    case OPAQUE:
      return llvm::Type::getInt8PtrTy(module->getContext());
    case TOKEN:
      // Tokens do not have a physical representation, but the compiler needs
      // some placeholder type, so use int8*.
      return llvm::Type::getInt8PtrTy(module->getContext());
    default:
      LOG(FATAL) << "unsupported type " << element_type;
  }
}

// Returns the bit width of `type`. Packed structs are summed member by
// member (recursively); any other type must have a primitive size, and a
// zero size (unsized type) is a CHECK failure.
int GetSizeInBits(llvm::Type* type) {
  const llvm::StructType* struct_ty = llvm::dyn_cast<llvm::StructType>(type);
  if (struct_ty) {
    CHECK(struct_ty->isPacked());
    int bits = 0;
    for (auto element_type : struct_ty->elements()) {
      bits += GetSizeInBits(element_type);
    }
    return bits;
  }
  int bits = type->getPrimitiveSizeInBits();
  CHECK_GT(bits, 0) << "type is not sized";
  return bits;
}

// Maps an XLA shape to its LLVM buffer type: element type wrapped in nested
// arrays, innermost dimension first (minor-to-major), so the physical layout
// matches the shape's layout.
llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) {
  llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module);
  if (shape.IsTuple()) {
    // A tuple buffer is an array of pointers.
    result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size());
  } else if (shape.IsArray()) {
    for (int64 dimension : LayoutUtil::MinorToMajor(shape)) {
      result_type =
          llvm::ArrayType::get(result_type, shape.dimensions(dimension));
    }
  }
  return result_type;
}

// Serializes `shape` as its proto wire format into a global string constant
// and returns a pointer to it; the byte count is returned via `shape_size`.
// Fails if the encoding does not fit in an int32.
StatusOr<llvm::Value*> EncodeSelfDescribingShapeConstant(const Shape& shape,
                                                         int32* shape_size,
                                                         llvm::IRBuilder<>* b) {
  string encoded_shape = shape.SerializeAsString();
  if (encoded_shape.size() > std::numeric_limits<int32>::max()) {
    return InternalError("Encoded shape size exceeded int32 size limit.");
  }
  *shape_size = static_cast<int32>(encoded_shape.size());
  return b->CreateGlobalStringPtr(encoded_shape);
}

// Inverse of EncodeSelfDescribingShapeConstant: parses `size_bytes` bytes at
// `shape_ptr` as a ShapeProto and validates the resulting shape.
StatusOr<Shape> DecodeSelfDescribingShapeConstant(const void* shape_ptr,
                                                  int32 size_bytes) {
  ShapeProto shape_proto;
  TF_RET_CHECK(shape_proto.ParseFromArray(shape_ptr, size_bytes));
  Shape shape(shape_proto);
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
  return std::move(shape);
}

// Converts a literal to an LLVM constant by reinterpreting its raw bytes as
// a constant data array. Only valid when host and target endianness agree,
// which the CHECK enforces.
llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
                                           llvm::Module* module) {
  const char* data = static_cast<const char*>(literal.untyped_data());
  CHECK_EQ(module->getDataLayout().isLittleEndian(),
           tensorflow::port::kLittleEndian);
  return llvm::ConstantDataArray::getString(
      module->getContext(), llvm::StringRef(data, literal.size_bytes()),
      /*AddNull=*/false);
}

// Creates a global in NVPTX address space 3 (shared memory) to hold a tile
// of type `tile_type`, initialized to undef.
llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module,
                                               llvm::Type* tile_type,
                                               absl::string_view name) {
  const int kNVPTXSharedMemoryAddrSpace = 3;
  return new llvm::GlobalVariable(
      *module, tile_type,
      /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
      llvm::UndefValue::get(tile_type), AsStringRef(name), nullptr,
      llvm::GlobalValue::NotThreadLocal, kNVPTXSharedMemoryAddrSpace);
}

// Emits a single-element alloca; see EmitAllocaAtFunctionEntryWithCount.
llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
                                            absl::string_view name,
                                            llvm::IRBuilder<>* b,
                                            int alignment) {
  return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, b, alignment);
}

// Emits an alloca in the entry block of the current function (so it is a
// static alloca that mem2reg/SROA can handle) without disturbing the
// caller's insertion point; the InsertPointGuard restores it on return.
// `alignment` of 0 means "use the default alignment".
llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type,
                                                     llvm::Value* element_count,
                                                     absl::string_view name,
                                                     llvm::IRBuilder<>* b,
                                                     int alignment) {
  llvm::IRBuilder<>::InsertPointGuard guard(*b);
  llvm::Function* function = b->GetInsertBlock()->getParent();
  b->SetInsertPoint(&function->getEntryBlock(),
                    function->getEntryBlock().getFirstInsertionPt());
  llvm::AllocaInst* alloca =
      b->CreateAlloca(type, element_count, AsStringRef(name));
  if (alignment != 0) {
    alloca->setAlignment(alignment);
  }
  return alloca;
}

// Creates a basic block in the function the builder is currently inserting
// into, placed before `insert_before` (or appended when it is null).
llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
                                   absl::string_view name,
                                   llvm::IRBuilder<>* b) {
  return llvm::BasicBlock::Create(
      /*Context=*/b->getContext(),
      /*Name=*/AsStringRef(name),
      /*Parent=*/b->GetInsertBlock()->getParent(),
      /*InsertBefore*/ insert_before);
}

// Emits the skeleton of an if (and optionally else) construct: creates the
// true/false/after blocks, wires up the branches, and leaves the builder
// positioned at the start of the after block. Callers fill in the true/false
// block bodies afterwards via the returned LlvmIfData.
LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
                          llvm::IRBuilder<>* b, bool emit_else) {
  llvm_ir::LlvmIfData if_data;
  if_data.if_block = b->GetInsertBlock();
  if_data.true_block =
      CreateBasicBlock(nullptr, absl::StrCat(name, "-true"), b);
  if_data.false_block =
      emit_else ? CreateBasicBlock(nullptr, absl::StrCat(name, "-false"), b)
                : nullptr;

  // Add a terminator to the if block, if necessary.
  if (if_data.if_block->getTerminator() == nullptr) {
    b->SetInsertPoint(if_data.if_block);
    if_data.after_block =
        CreateBasicBlock(nullptr, absl::StrCat(name, "-after"), b);
    b->CreateBr(if_data.after_block);
  } else {
    // The block already ends in a terminator: split it at the current
    // insertion point so the tail becomes the after block.
    if_data.after_block = if_data.if_block->splitBasicBlock(
        b->GetInsertPoint(), absl::StrCat(name, "-after"));
  }

  // Our basic block should now end with an unconditional branch.  Remove it;
  // we're going to replace it with a conditional branch.
  if_data.if_block->getTerminator()->eraseFromParent();

  b->SetInsertPoint(if_data.if_block);
  b->CreateCondBr(condition, if_data.true_block,
                  emit_else ? if_data.false_block : if_data.after_block);

  b->SetInsertPoint(if_data.true_block);
  b->CreateBr(if_data.after_block);

  if (emit_else) {
    b->SetInsertPoint(if_data.false_block);
    b->CreateBr(if_data.after_block);
  }

  b->SetInsertPoint(if_data.after_block,
                    if_data.after_block->getFirstInsertionPt());

  return if_data;
}

// Emits an integer or floating-point compare (chosen from the operand type)
// and widens the result to the PRED storage type.
llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
                            llvm::Value* lhs_value, llvm::Value* rhs_value,
                            llvm::IRBuilder<>* b) {
  llvm::Value* comparison_result;
  if (lhs_value->getType()->isIntegerTy()) {
    comparison_result = b->CreateICmp(predicate, lhs_value, rhs_value);
  } else {
    comparison_result = b->CreateFCmp(predicate, lhs_value, rhs_value);
  }
  // comparison_result is i1, but the NVPTX codegen incorrectly lowers i1
  // arrays. So we extend it to i8 so that it's addressable.
  return b->CreateZExt(comparison_result, llvm_ir::PrimitiveTypeToIrType(
                                              PRED, ModuleFromIRBuilder(b)));
}

// Internal helper that is called from emitted code to log an int64 value with a
// tag.
static void LogS64(const char* tag, int64 value) {
  LOG(INFO) << tag << " (int64): " << value;
}

// Emits a call back into the compiler process: the addresses of LogS64 and
// of `tag` are baked into the IR as i64 constants and int-to-ptr'd at the
// call site. Only valid for JIT'd code running in this same process.
void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* b) {
  llvm::FunctionType* log_function_type = llvm::FunctionType::get(
      b->getVoidTy(), {b->getInt64Ty(), b->getInt64Ty()}, /*isVarArg=*/false);
  b->CreateCall(log_function_type,
                b->CreateIntToPtr(b->getInt64(absl::bit_cast<int64>(&LogS64)),
                                  log_function_type->getPointerTo()),
                {b->getInt64(absl::bit_cast<int64>(tag)), value});
}

// Attaches !align metadata asserting the loaded pointer is aligned to
// `alignment` bytes.
void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) {
  llvm::LLVMContext& context = load->getContext();
  llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
  llvm::Constant* alignment_constant =
      llvm::ConstantInt::get(int64_ty, alignment);
  llvm::MDBuilder metadata_builder(context);
  auto* alignment_metadata =
      metadata_builder.createConstant(alignment_constant);
  load->setMetadata(llvm::LLVMContext::MD_align,
                    llvm::MDNode::get(context, alignment_metadata));
}

// Attaches !dereferenceable metadata asserting the loaded pointer points to
// at least `dereferenceable_bytes` accessible bytes.
void SetDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                       uint64_t dereferenceable_bytes) {
  llvm::LLVMContext& context = load->getContext();
  llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
  llvm::Constant* dereferenceable_bytes_constant =
      llvm::ConstantInt::get(int64_ty, dereferenceable_bytes);
  llvm::MDBuilder metadata_builder(context);
  auto* dereferenceable_bytes_metadata =
      metadata_builder.createConstant(dereferenceable_bytes_constant);
  load->setMetadata(llvm::LLVMContext::MD_dereferenceable,
                    llvm::MDNode::get(context, dereferenceable_bytes_metadata));
}

// Attaches !range metadata [lower, upper) to `inst` and returns it. Note the
// bounds are emitted as i32 constants, so int64 bounds outside i32 range
// would be truncated — callers pass values known to fit.
llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper,
                                    llvm::Instruction* inst) {
  llvm::LLVMContext& context = inst->getParent()->getContext();
  llvm::IntegerType* i32 = llvm::Type::getInt32Ty(context);
  inst->setMetadata(
      llvm::LLVMContext::MD_range,
      llvm::MDNode::get(
          context,
          {llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, lower)),
           llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, upper))}));
  return inst;
}

// Strips '%' characters (LLVM's value-name sigil) from a name.
string IrName(string a) {
  a.erase(std::remove(a.begin(), a.end(), '%'), a.end());
  return a;
}

// Joins two name fragments with '.', omitting the separator when either
// side is empty, then sanitizes via IrName(string).
string IrName(absl::string_view a, absl::string_view b) {
  if (!a.empty() && !b.empty()) {
    return IrName(absl::StrCat(a, ".", b));
  }
  return IrName(absl::StrCat(a, b));
}

// Derives an IR name from an HLO instruction's name plus a suffix.
string IrName(const HloInstruction* a, absl::string_view b) {
  return IrName(a->name(), b);
}

string SanitizeFunctionName(string function_name) {
  // The backend with the strictest requirements on function names is NVPTX, so
  // we sanitize to its requirements.
  //
  // A slightly stricter version of the NVPTX requirements is that names match
  // /[a-zA-Z_$][a-zA-Z0-9_$]*/, with the exception that the names "_" and "$"
  // are illegal.

  // Sanitize chars in function_name: anything outside [a-zA-Z0-9_$] becomes
  // an underscore.
  std::transform(function_name.begin(), function_name.end(),
                 function_name.begin(), [](char c) {
                   if (('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') ||
                       ('0' <= c && c <= '9') || c == '_' || c == '$') {
                     return c;
                   }
                   return '_';
                 });

  // Ensure the name isn't empty.
  if (function_name.empty()) {
    function_name = "__unnamed";
  }

  // Ensure the name doesn't start with a number.
  if (!function_name.empty() && function_name[0] >= '0' &&
      function_name[0] <= '9') {
    function_name.insert(function_name.begin(), '_');
  }

  // Ensure the name isn't "_" or "$".
  if (function_name == "_" || function_name == "$") {
    function_name += '_';
  }

  return function_name;
}

// Positions `builder` at the first point in `blk` where non-PHI
// instructions may be inserted.
void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
  builder->SetInsertPoint(blk, blk->getFirstInsertionPt());
}

// Positions `builder` just before `blk`'s terminator, or at the end of the
// block if it has no terminator yet.
void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
  if (llvm::Instruction* terminator = blk->getTerminator()) {
    builder->SetInsertPoint(terminator);
  } else {
    builder->SetInsertPoint(blk);
  }
}

// Emits a rotate-right of `rotand` by `rotor` bits as (x << (w - r % w)) |
// (x >> (r % w)). Both shift amounts are reduced modulo the bit width so
// neither shift can equal the width (which would be poison in LLVM IR).
llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
                       llvm::IRBuilder<>* builder) {
  auto size = rotand->getType()->getPrimitiveSizeInBits();
  auto size_value = builder->getIntN(size, size);
  auto mod = [=](llvm::Value* x) { return builder->CreateURem(x, size_value); };
  return builder->CreateOr(
      builder->CreateShl(rotand, mod(builder->CreateSub(size_value, rotor))),
      builder->CreateLShr(rotand, mod(rotor)));
}

// Returns the byte size of `shape`, using the data layout's pointer size
// for the pointer-valued parts (e.g. tuple entries).
int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) {
  unsigned pointer_size = data_layout.getPointerSize();
  return ShapeUtil::ByteSizeOf(shape, pointer_size);
}

// Builds the fast-math flag set for CPU codegen from the module's debug
// options: everything off unless fast math is enabled, then "fast" minus
// whichever of NaN/Inf handling the honor flags re-enable.
llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config) {
  llvm::FastMathFlags flags;
  if (!module_config.debug_options().xla_cpu_enable_fast_math()) {
    return flags;
  }

  // Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros, AllowReciprocal,
  // AllowContract, and ApproxFunc.
  flags.setFast();

  if (module_config.debug_options().xla_cpu_fast_math_honor_nans()) {
    flags.setNoNaNs(false);
  }

  if (module_config.debug_options().xla_cpu_fast_math_honor_infs()) {
    flags.setNoInfs(false);
  }

  return flags;
}

// Merges two kind->MDNode metadata maps. For !alias.scope the result is the
// union of both scope lists (either instruction's aliasing promise holds);
// for !noalias it is the intersection (only promises both made survive).
// Kinds present only in `b` are dropped, as are kinds other than the two
// handled here.
std::map<int, llvm::MDNode*> MergeMetadata(
    llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
    const std::map<int, llvm::MDNode*>& b) {
  // We should extend this as needed to deal with other kinds of metadata like
  // !dereferenceable and !range.

  std::map<int, llvm::MDNode*> result;
  for (auto kind_md_pair : a) {
    if (kind_md_pair.first == llvm::LLVMContext::MD_alias_scope) {
      llvm::SmallVector<llvm::Metadata*, 8> union_of_scopes;
      llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
      for (const auto& scope_a : kind_md_pair.second->operands()) {
        scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
        union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_a.get()));
      }
      auto it = b.find(kind_md_pair.first);
      if (it != b.end()) {
        for (const auto& scope_b : it->second->operands()) {
          if (!scope_set.count(llvm::cast<llvm::MDNode>(scope_b.get()))) {
            union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b.get()));
          }
        }
      }
      result[llvm::LLVMContext::MD_alias_scope] =
          llvm::MDNode::get(*context, union_of_scopes);
    } else if (kind_md_pair.first == llvm::LLVMContext::MD_noalias) {
      llvm::SmallVector<llvm::Metadata*, 8> intersection_of_scopes;
      llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
      for (const auto& scope_a : kind_md_pair.second->operands()) {
        scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
      }
      auto it = b.find(kind_md_pair.first);
      if (it != b.end()) {
        for (const auto& scope_b : it->second->operands()) {
          if (scope_set.count(llvm::cast<llvm::MDNode>(scope_b))) {
            intersection_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b));
          }
        }
      }
      // An empty !noalias list would be meaningless; omit the kind entirely.
      if (!intersection_of_scopes.empty()) {
        result[llvm::LLVMContext::MD_noalias] =
            llvm::MDNode::get(*context, intersection_of_scopes);
      }
    }
  }
  return result;
}

// Creates `directory_name` (recursively) if needed and writes `text` to
// `file_name` inside it.
static Status CreateAndWriteStringToFile(const string& directory_name,
                                         const string& file_name,
                                         const string& text) {
  std::unique_ptr<tensorflow::WritableFile> f;
  TF_RETURN_IF_ERROR(
      tensorflow::Env::Default()->RecursivelyCreateDir(directory_name));
  TF_RETURN_IF_ERROR(
      tensorflow::Env::Default()->NewWritableFile(file_name, &f));
  TF_RETURN_IF_ERROR(f->Append(text));
  TF_RETURN_IF_ERROR(f->Close());
  return Status::OK();
}

// Dumps the LLVM module's IR (pre- or post-optimization, per `optimized`)
// if dumping is enabled for `hlo_module`.
void DumpIrIfEnabled(const HloModule& hlo_module,
                     const llvm::Module& llvm_module, bool optimized) {
  const auto& debug_opts = hlo_module.config().debug_options();
  if (!DumpingEnabledForHloModule(hlo_module)) {
    return;
  }
  // We can end up compiling different modules with the same name when using
  // XlaJitCompiledCpuFunction::Compile.  Avoid overwriting IR files previously
  // dumped from the same process in such cases.
  string suffix = absl::StrCat("ir-", optimized ? "with" : "no", "-opt");
  DumpToFileInDirOrStdout(hlo_module, absl::StrCat(suffix, ".ll"),
                          DumpModuleToString(llvm_module));

  // For some models the embedded constants can be huge, so also dump the module
  // with the constants stripped to get IR that is easier to manipulate.  Skip
  // this if we're dumping to stdout; there's no point in duplicating everything
  // when writing to the terminal.
  if (!DumpingToStdout(debug_opts)) {
    DumpToFileInDir(hlo_module, absl::StrCat(suffix, "-noconst.ll"),
                    DumpModuleToString(*DropConstantInitializers(llvm_module)));
  }
}

// Creates an empty function for CPU codegen, with calling convention,
// unwind-table, and fast-math/size attributes derived from `module_config`.
llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
                                  llvm::GlobalValue::LinkageTypes linkage,
                                  const HloModuleConfig& module_config,
                                  absl::string_view name,
                                  llvm::Module* module) {
  llvm::Function* function =
      llvm::Function::Create(function_type, linkage, AsStringRef(name), module);
  function->setCallingConv(llvm::CallingConv::C);
  function->addFnAttr("no-frame-pointer-elim", "false");

  // Generate unwind information so that GDB can crawl through the stack frames
  // created by the JIT compiled code.
  function->setHasUWTable();

  if (module_config.debug_options().xla_cpu_enable_fast_math()) {
    function->addFnAttr("unsafe-fp-math", "true");
    function->addFnAttr("no-signed-zeros-fp-math", "true");

    if (!module_config.debug_options().xla_cpu_fast_math_honor_nans()) {
      function->addFnAttr("no-nans-fp-math", "true");
    }

    if (!module_config.debug_options().xla_cpu_fast_math_honor_infs()) {
      function->addFnAttr("no-infs-fp-math", "true");
    }
  }

  // Add the optimize attribute to the function if optimizing for size. This
  // controls internal behavior of some optimization passes (e.g. loop
  // unrolling).
  if (cpu::options::OptimizeForSizeRequested(module_config)) {
    function->addFnAttr(llvm::Attribute::OptimizeForSize);
  }

  return function;
}

// Forwards xla_backend_extra_options entries (minus those the XLA backend
// itself consumes, i.e. keys starting with "xla_") to LLVM's command-line
// option parser, formatted as a synthetic argv.
void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) {
  auto options = config.debug_options().xla_backend_extra_options();
  if (!options.empty()) {
    std::vector<string> fake_argv_storage;
    // argv[0] is conventionally the program name; LLVM ignores its value.
    fake_argv_storage.push_back("");
    for (const auto& it : options) {
      // Skip options the XLA backend itself consumes.
      if (!absl::StartsWith(it.first, "xla_")) {
        if (it.second.empty()) {
          fake_argv_storage.push_back(it.first);
        } else {
          fake_argv_storage.push_back(it.first + "=" + it.second);
        }
      }
    }

    VLOG(2) << "Passing argv to LLVM:";
    std::vector<const char*> fake_argv;
    for (const auto& s : fake_argv_storage) {
      fake_argv.push_back(s.c_str());
      VLOG(2) << s;
    }
    llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]);
  }
}

// Emits IR computing the full 64-bit product of two u32 values, returned as
// a (low 32 bits, high 32 bits) pair.
std::pair<llvm::Value*, llvm::Value*> UMulLowHigh32(llvm::IRBuilder<>* b,
                                                    llvm::Value* src0,
                                                    llvm::Value* src1) {
  CHECK_EQ(src0->getType()->getPrimitiveSizeInBits(), 32);
  CHECK_EQ(src1->getType()->getPrimitiveSizeInBits(), 32);
  llvm::Type* int64_ty = b->getInt64Ty();
  src0 = b->CreateZExt(src0, int64_ty);
  src1 = b->CreateZExt(src1, int64_ty);
  return SplitInt64ToInt32s(b, b->CreateMul(src0, src1));
}

// Emits IR splitting a 64-bit value into its (low, high) 32-bit halves.
std::pair<llvm::Value*, llvm::Value*> SplitInt64ToInt32s(
    llvm::IRBuilder<>* b, llvm::Value* value_64bits) {
  CHECK_EQ(value_64bits->getType()->getPrimitiveSizeInBits(), 64);
  llvm::Type* int32_ty = b->getInt32Ty();
  llvm::Value* low_32bits = b->CreateTrunc(value_64bits, int32_ty);
  llvm::Value* high_32bits =
      b->CreateTrunc(b->CreateLShr(value_64bits, 32), int32_ty);
  return std::make_pair(low_32bits, high_32bits);
}

// Returns (creating on first use) the module-private i64 global holding the
// Philox RNG state, initialized to 0.
llvm::GlobalVariable* GetOrCreateVariableForPhiloxRngState(
    llvm::Module* module, llvm::IRBuilder<>* b) {
  static const char* kPhiloxRngStateVariableName = "philox_rng_state";
  llvm::GlobalVariable* state_ptr =
      module->getNamedGlobal(kPhiloxRngStateVariableName);
  if (!state_ptr) {
    state_ptr = new llvm::GlobalVariable(
        /*M=*/*module,
        /*Ty=*/b->getInt64Ty(),
        /*isConstant=*/false,
        /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
        /*Initializer=*/b->getInt64(0),
        /*Name=*/kPhiloxRngStateVariableName);
  }
  return state_ptr;
}

// Emits a load-add-store sequence that bumps the Philox RNG state global by
// `value`.
void IncrementVariableForPhiloxRngState(int64 value, llvm::Module* module,
                                        llvm::IRBuilder<>* builder) {
  llvm::GlobalVariable* state_ptr =
      GetOrCreateVariableForPhiloxRngState(module, builder);
  llvm::Value* state_value_old = builder->CreateLoad(state_ptr, "load_state");
  // If the 64-bit value overflows, we use the wraparound value. This should
  // be fine in practice as we only add one to the value each time when a RNG is
  // executed.
  llvm::Value* state_value_new = builder->CreateAdd(
      state_value_old, builder->getInt64(value), "inc_state");
  builder->CreateStore(state_value_new, state_ptr);
}

}  // namespace llvm_ir
}  // namespace xla