1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" 17 18 #include <algorithm> 19 #include <memory> 20 #include <vector> 21 22 #include "llvm/IR/DerivedTypes.h" 23 #include "llvm/IR/GlobalValue.h" 24 #include "llvm/IR/MDBuilder.h" 25 #include "llvm/IR/Operator.h" 26 #include "llvm/Target/TargetOptions.h" 27 #include "llvm/Transforms/Utils/Cloning.h" 28 #include "tensorflow/compiler/xla/layout_util.h" 29 #include "tensorflow/compiler/xla/literal_util.h" 30 #include "tensorflow/compiler/xla/service/name_uniquer.h" 31 #include "tensorflow/compiler/xla/shape_util.h" 32 #include "tensorflow/compiler/xla/types.h" 33 #include "tensorflow/compiler/xla/util.h" 34 #include "tensorflow/core/lib/core/casts.h" 35 #include "tensorflow/core/lib/core/errors.h" 36 #include "tensorflow/core/lib/io/path.h" 37 #include "tensorflow/core/lib/strings/strcat.h" 38 #include "tensorflow/core/platform/env.h" 39 #include "tensorflow/core/platform/logging.h" 40 #include "tensorflow/core/platform/types.h" 41 42 namespace xla { 43 namespace llvm_ir { 44 45 namespace { 46 47 // Note, this function is only useful in an insertion context; in a global 48 // (e.g. constants) context it will CHECK fail. 49 llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* ir_builder) { 50 auto block = CHECK_NOTNULL(ir_builder->GetInsertBlock()); 51 auto fn = CHECK_NOTNULL(block->getParent()); 52 auto module = CHECK_NOTNULL(fn->getParent()); 53 return module; 54 } 55 56 } // namespace 57 58 string AsString(const std::string& str) { 59 return string(str.data(), str.length()); 60 } 61 62 llvm::StringRef AsStringRef(tensorflow::StringPiece str) { 63 return llvm::StringRef(str.data(), str.size()); 64 } 65 66 std::unique_ptr<llvm::Module> DropConstantInitializers( 67 const llvm::Module& module) { 68 std::unique_ptr<llvm::Module> cloned_module = CloneModule(module); 69 for (llvm::GlobalVariable& global_var : cloned_module->globals()) { 70 global_var.setInitializer(nullptr); 71 global_var.setLinkage(llvm::GlobalValue::LinkageTypes::ExternalLinkage); 72 } 73 return cloned_module; 74 } 75 76 string DumpModuleToString(const llvm::Module& module) { 77 std::string buffer_string; 78 llvm::raw_string_ostream ostream(buffer_string); 79 module.print(ostream, nullptr); 80 ostream.flush(); 81 return AsString(buffer_string); 82 } 83 84 llvm::Value* EmitCallToIntrinsic( 85 llvm::Intrinsic::ID intrinsic_id, 86 tensorflow::gtl::ArraySlice<llvm::Value*> operands, 87 tensorflow::gtl::ArraySlice<llvm::Type*> overloaded_types, 88 llvm::IRBuilder<>* ir_builder) { 89 std::vector<llvm::Type*> types; 90 for (auto type : overloaded_types) { 91 types.push_back(type); 92 } 93 llvm::Module* module = ModuleFromIRBuilder(ir_builder); 94 llvm::Function* intrinsic = 95 llvm::Intrinsic::getDeclaration(module, intrinsic_id, types); 96 std::vector<llvm::Value*> operands_vec; 97 for (auto operand : operands) { 98 operands_vec.push_back(operand); 99 } 100 return ir_builder->CreateCall(intrinsic, operands_vec); 101 } 102 103 llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, 104 llvm::IRBuilder<>* ir_builder) { 105 if (ir_builder->getFastMathFlags().noNaNs()) { 106 auto cmp = ir_builder->CreateFCmpUGE(lhs_value, rhs_value); 107 return ir_builder->CreateSelect(cmp, lhs_value, rhs_value); 108 } else { 109 return EmitCallToIntrinsic(llvm::Intrinsic::maxnum, {lhs_value, rhs_value}, 110 {lhs_value->getType()}, ir_builder); 111 } 112 } 113 114 llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value, 115 llvm::IRBuilder<>* ir_builder) { 116 if (ir_builder->getFastMathFlags().noNaNs()) { 117 auto cmp = ir_builder->CreateFCmpULE(lhs_value, rhs_value); 118 return ir_builder->CreateSelect(cmp, lhs_value, rhs_value); 119 } else { 120 return EmitCallToIntrinsic(llvm::Intrinsic::minnum, {lhs_value, rhs_value}, 121 {lhs_value->getType()}, ir_builder); 122 } 123 } 124 125 llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index, 126 llvm::IRBuilder<>* ir_builder) { 127 llvm::Type* array_type = array->getType(); 128 CHECK(array_type->isPointerTy()); 129 llvm::PointerType* array_type_as_pointer = 130 llvm::cast<llvm::PointerType>(array_type); 131 VLOG(2) << "EmitBufferIndexingGEP with type=" 132 << llvm_ir::DumpToString(*array_type) 133 << " array=" << llvm_ir::DumpToString(*array) 134 << " index=" << llvm_ir::DumpToString(*index); 135 136 return ir_builder->CreateInBoundsGEP( 137 array_type_as_pointer->getElementType(), array, 138 llvm::isa<llvm::GlobalVariable>(array) 139 ? llvm::ArrayRef<llvm::Value*>({ir_builder->getInt64(0), index}) 140 : index); 141 } 142 143 llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index, 144 llvm::IRBuilder<>* ir_builder) { 145 return EmitBufferIndexingGEP(array, ir_builder->getInt64(index), ir_builder); 146 } 147 148 llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, 149 llvm::Module* module) { 150 switch (element_type) { 151 case PRED: 152 case S8: 153 case U8: 154 return llvm::Type::getInt8Ty(module->getContext()); 155 case S16: 156 case U16: 157 case BF16: 158 // For BF16 we just need some type that is 16 bits wide so that it will 159 // take up the right amount of space in memory. LLVM does not have a BF16 160 // type (the LLVM half type is IEEE 16 bit floating point, not bfloat), so 161 // we can't map it directly to an LLVM type. We will not map a BF16 162 // addition to an addition on this type (int16) - this is just the type 163 // used for storage. 164 return llvm::Type::getInt16Ty(module->getContext()); 165 case F16: 166 return llvm::Type::getHalfTy(module->getContext()); 167 case S32: 168 case U32: 169 return llvm::Type::getInt32Ty(module->getContext()); 170 case S64: 171 case U64: 172 return llvm::Type::getInt64Ty(module->getContext()); 173 case F32: 174 return llvm::Type::getFloatTy(module->getContext()); 175 case F64: 176 return llvm::Type::getDoubleTy(module->getContext()); 177 case C64: { 178 auto cplx_t = module->getTypeByName("complex64"); 179 if (cplx_t == nullptr) { 180 // C++ standard dictates the memory layout of std::complex is contiguous 181 // real followed by imaginary. C++11 section 26.4 [complex.numbers]: 182 // If z is an lvalue expression of type cv std::complex<T> then the 183 // expression reinterpret_cast<cv T(&)[2]>(z) shall be well-formed, 184 // reinterpret_cast<cv T(&)[2]>(z)[0] shall designate the real part of 185 // z, and reinterpret_cast<cv T(&)[2]>(z)[1] shall designate the 186 // imaginary part of z. 187 return llvm::StructType::create( 188 {llvm::Type::getFloatTy(module->getContext()), 189 llvm::Type::getFloatTy(module->getContext())}, 190 "complex64", /*isPacked=*/true); 191 } 192 return cplx_t; 193 } 194 // A Tuple contains an array of pointers. Use i8*. 195 case TUPLE: 196 // An Opaque is like a void*, use i8*. 197 case OPAQUE: 198 return llvm::Type::getInt8PtrTy(module->getContext()); 199 default: 200 LOG(FATAL) << "unsupported type " << element_type; 201 } 202 } 203 204 int GetSizeInBits(llvm::Type* type) { 205 const llvm::StructType* struct_ty = llvm::dyn_cast<llvm::StructType>(type); 206 if (struct_ty) { 207 CHECK(struct_ty->isPacked()); 208 int bits = 0; 209 for (auto element_type : struct_ty->elements()) { 210 bits += GetSizeInBits(element_type); 211 } 212 return bits; 213 } 214 int bits = type->getPrimitiveSizeInBits(); 215 CHECK_GT(bits, 0) << "type is not sized"; 216 return bits; 217 } 218 219 llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) { 220 llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module); 221 if (ShapeUtil::IsTuple(shape)) { 222 // A tuple buffer is an array of pointers. 223 result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size()); 224 } else if (ShapeUtil::IsArray(shape)) { 225 for (int64 dimension : LayoutUtil::MinorToMajor(shape)) { 226 result_type = 227 llvm::ArrayType::get(result_type, shape.dimensions(dimension)); 228 } 229 } 230 return result_type; 231 } 232 233 StatusOr<llvm::Value*> EncodeSelfDescribingShapeConstant( 234 const Shape& shape, int32* shape_size, llvm::IRBuilder<>* ir_builder) { 235 string encoded_shape = shape.SerializeAsString(); 236 if (encoded_shape.size() > std::numeric_limits<int32>::max()) { 237 return InternalError("Encoded shape size exceeded int32 size limit."); 238 } 239 *shape_size = static_cast<int32>(encoded_shape.size()); 240 return ir_builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(encoded_shape)); 241 } 242 243 StatusOr<Shape> DecodeSelfDescribingShapeConstant(const void* shape_ptr, 244 int32 size_bytes) { 245 Shape shape; 246 TF_RET_CHECK(shape.ParseFromArray(shape_ptr, size_bytes)); 247 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); 248 return shape; 249 } 250 251 namespace { 252 253 // Recursively construct a multidimensional LLVM constant which represents the 254 // given literal. The minor-to-major dimension ordering in the constant matches 255 // that of the literal. For example, given a [2 x 3 x 4] Literal (dimension 0 256 // has size 4, dimension 1 has size 3, etc) of primitive type F32 with a 257 // minor_to_major value of [2, 1, 0] (column major), a LLVM constant of type 258 // [4 x [3 x [2 x float]] will be returned. 259 // 260 // multi_index is a multidimensional index into the array. dimension_index is an 261 // index into the minor_to_major field in the literal shape. This determines 262 // which dimension is iterated over in this level of the recursion. Dimensions 263 // are iterated from most major down to most minor (highest dimension_index 264 // value down to zero). 265 llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, 266 std::vector<int64>* multi_index, 267 llvm::Module* module) { 268 const Shape& shape = literal.shape(); 269 llvm::Type* ir_element_type = 270 llvm_ir::PrimitiveTypeToIrType(shape.element_type(), module); 271 if (dimension_index == -1) { 272 // Base case of the recursion. Index into the data field of the protobuf 273 // with the multi index. 274 llvm::Constant* value; 275 switch (shape.element_type()) { 276 case PRED: 277 value = llvm::ConstantInt::get(ir_element_type, 278 literal.Get<bool>(*multi_index)); 279 break; 280 case U8: 281 value = llvm::ConstantInt::get(ir_element_type, 282 literal.Get<uint8>(*multi_index)); 283 break; 284 case S32: 285 value = llvm::ConstantInt::get(ir_element_type, 286 literal.Get<int32>(*multi_index)); 287 break; 288 case U32: 289 value = llvm::ConstantInt::get(ir_element_type, 290 literal.Get<uint32>(*multi_index)); 291 break; 292 case S64: 293 value = llvm::ConstantInt::get(ir_element_type, 294 literal.Get<int64>(*multi_index)); 295 break; 296 case U64: 297 value = llvm::ConstantInt::get(ir_element_type, 298 literal.Get<uint64>(*multi_index)); 299 break; 300 case F32: 301 value = llvm::ConstantFP::get(ir_element_type, 302 literal.Get<float>(*multi_index)); 303 break; 304 case BF16: 305 value = llvm::ConstantInt::get( 306 ir_element_type, 307 tensorflow::bit_cast<uint16>(literal.Get<bfloat16>(*multi_index))); 308 break; 309 case F16: 310 value = llvm::ConstantFP::get( 311 ir_element_type, 312 static_cast<float>(literal.Get<half>(*multi_index))); 313 break; 314 case F64: 315 value = llvm::ConstantFP::get(ir_element_type, 316 literal.Get<double>(*multi_index)); 317 break; 318 case C64: { 319 complex64 x = literal.Get<complex64>(*multi_index); 320 value = llvm::ConstantStruct::get( 321 static_cast<llvm::StructType*>(ir_element_type), 322 llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module), 323 x.real()), 324 llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module), 325 x.imag())); 326 break; 327 } 328 default: 329 LOG(FATAL) << "unsupported type " << shape.element_type(); 330 } 331 return value; 332 } 333 334 // The dimension index starts at the one less than the rank of the array and 335 // decrements with each recursive call. We want to iterate through the 336 // dimensions in major-to-minor order as we recurse so just index into 337 // minor_to_major to get the dimension number for this level of the recursion. 338 int64 dimension = LayoutUtil::Minor(shape.layout(), dimension_index); 339 340 // Recursively call LiteralToConstant to construct subarrays for the 341 // more-minor dimensions. Gather the subarrays into a vector for bundling into 342 // a new (higher-dimensional) ConstantArray. 343 std::vector<llvm::Constant*> elements; 344 for (int64 i = 0; i < shape.dimensions(dimension); ++i) { 345 (*multi_index)[dimension] = i; 346 elements.push_back( 347 LiteralToConstant(literal, dimension_index - 1, multi_index, module)); 348 } 349 350 llvm::Type* element_type; 351 if (elements.empty()) { 352 element_type = ir_element_type; 353 for (int i = 0; i < dimension_index; ++i) { 354 int64 index = LayoutUtil::Minor(shape.layout(), i); 355 element_type = 356 llvm::ArrayType::get(element_type, shape.dimensions(index)); 357 } 358 } else { 359 element_type = elements[0]->getType(); 360 } 361 llvm::ArrayType* aggregate_type = 362 llvm::ArrayType::get(element_type, shape.dimensions(dimension)); 363 return llvm::ConstantArray::get(aggregate_type, elements); 364 } 365 366 } // namespace 367 368 llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, 369 llvm::Module* module) { 370 std::vector<int64> multi_index(ShapeUtil::Rank(literal.shape()), 0); 371 llvm::Constant* value = LiteralToConstant( 372 literal, /*dimension_index=*/ShapeUtil::Rank(literal.shape()) - 1, 373 &multi_index, module); 374 return value; 375 } 376 377 llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, 378 tensorflow::StringPiece name, 379 llvm::IRBuilder<>* ir_builder, 380 int alignment) { 381 return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, ir_builder, 382 alignment); 383 } 384 385 llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( 386 llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name, 387 llvm::IRBuilder<>* ir_builder, int alignment) { 388 llvm::IRBuilder<>::InsertPoint insert_point = ir_builder->saveIP(); 389 llvm::Function* function = ir_builder->GetInsertBlock()->getParent(); 390 ir_builder->SetInsertPoint(&function->getEntryBlock(), 391 function->getEntryBlock().getFirstInsertionPt()); 392 llvm::AllocaInst* alloca = 393 ir_builder->CreateAlloca(type, element_count, AsStringRef(name)); 394 if (alignment != 0) { 395 alloca->setAlignment(alignment); 396 } 397 ir_builder->restoreIP(insert_point); 398 return alloca; 399 } 400 401 llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, 402 tensorflow::StringPiece name, 403 llvm::IRBuilder<>* ir_builder) { 404 return llvm::BasicBlock::Create( 405 /*Context=*/ir_builder->getContext(), 406 /*Name=*/AsStringRef(name), 407 /*Parent=*/ir_builder->GetInsertBlock()->getParent(), 408 /*InsertBefore*/ insert_before); 409 } 410 411 LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name, 412 llvm::IRBuilder<>* ir_builder, bool emit_else) { 413 llvm_ir::LlvmIfData if_data; 414 if_data.if_block = ir_builder->GetInsertBlock(); 415 if_data.true_block = CreateBasicBlock( 416 nullptr, tensorflow::strings::StrCat(name, "-true"), ir_builder); 417 if_data.false_block = 418 emit_else ? CreateBasicBlock(nullptr, 419 tensorflow::strings::StrCat(name, "-false"), 420 ir_builder) 421 : nullptr; 422 423 // Add a terminator to the if block, if necessary. 424 if (if_data.if_block->getTerminator() == nullptr) { 425 ir_builder->SetInsertPoint(if_data.if_block); 426 if_data.after_block = CreateBasicBlock( 427 nullptr, tensorflow::strings::StrCat(name, "-after"), ir_builder); 428 ir_builder->CreateBr(if_data.after_block); 429 } else { 430 if_data.after_block = if_data.if_block->splitBasicBlock( 431 ir_builder->GetInsertPoint(), 432 AsStringRef(tensorflow::strings::StrCat(name, "-after"))); 433 } 434 435 // Our basic block should now end with an unconditional branch. Remove it; 436 // we're going to replace it with a conditional branch. 437 if_data.if_block->getTerminator()->eraseFromParent(); 438 439 ir_builder->SetInsertPoint(if_data.if_block); 440 ir_builder->CreateCondBr( 441 condition, if_data.true_block, 442 emit_else ? if_data.false_block : if_data.after_block); 443 444 ir_builder->SetInsertPoint(if_data.true_block); 445 ir_builder->CreateBr(if_data.after_block); 446 447 if (emit_else) { 448 ir_builder->SetInsertPoint(if_data.false_block); 449 ir_builder->CreateBr(if_data.after_block); 450 } 451 452 ir_builder->SetInsertPoint(if_data.after_block, 453 if_data.after_block->getFirstInsertionPt()); 454 455 return if_data; 456 } 457 458 llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate, 459 llvm::Value* lhs_value, llvm::Value* rhs_value, 460 llvm::IRBuilder<>* ir_builder) { 461 llvm::Value* comparison_result; 462 if (lhs_value->getType()->isIntegerTy()) { 463 comparison_result = ir_builder->CreateICmp(predicate, lhs_value, rhs_value); 464 } else { 465 comparison_result = ir_builder->CreateFCmp(predicate, lhs_value, rhs_value); 466 } 467 // comparison_result is i1, but the NVPTX codegen incorrectly lowers i1 468 // arrays. So we extend it to i8 so that it's addressable. 469 return ir_builder->CreateZExt( 470 comparison_result, 471 llvm_ir::PrimitiveTypeToIrType(PRED, ModuleFromIRBuilder(ir_builder))); 472 } 473 474 // Internal helper that is called from emitted code to log an int64 value with a 475 // tag. 476 static void LogS64(const char* tag, int64 value) { 477 LOG(INFO) << tag << " (int64): " << value; 478 } 479 480 void EmitLogging(const char* tag, llvm::Value* value, 481 llvm::IRBuilder<>* ir_builder) { 482 llvm::FunctionType* log_function_type = llvm::FunctionType::get( 483 ir_builder->getVoidTy(), 484 {ir_builder->getInt64Ty(), ir_builder->getInt64Ty()}, /*isVarArg=*/false); 485 ir_builder->CreateCall( 486 log_function_type, 487 ir_builder->CreateIntToPtr( 488 ir_builder->getInt64(tensorflow::bit_cast<int64>(&LogS64)), 489 log_function_type->getPointerTo()), 490 {ir_builder->getInt64(tensorflow::bit_cast<int64>(tag)), value}); 491 } 492 493 void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) { 494 llvm::LLVMContext& context = load->getContext(); 495 llvm::Type* int64_ty = llvm::Type::getInt64Ty(context); 496 llvm::Constant* alignment_constant = 497 llvm::ConstantInt::get(int64_ty, alignment); 498 llvm::MDBuilder metadata_builder(context); 499 auto* alignment_metadata = 500 metadata_builder.createConstant(alignment_constant); 501 load->setMetadata(llvm::LLVMContext::MD_align, 502 llvm::MDNode::get(context, alignment_metadata)); 503 } 504 505 void SetDereferenceableMetadataForLoad(llvm::LoadInst* load, 506 uint64_t dereferenceable_bytes) { 507 llvm::LLVMContext& context = load->getContext(); 508 llvm::Type* int64_ty = llvm::Type::getInt64Ty(context); 509 llvm::Constant* dereferenceable_bytes_constant = 510 llvm::ConstantInt::get(int64_ty, dereferenceable_bytes); 511 llvm::MDBuilder metadata_builder(context); 512 auto* dereferenceable_bytes_metadata = 513 metadata_builder.createConstant(dereferenceable_bytes_constant); 514 load->setMetadata(llvm::LLVMContext::MD_dereferenceable, 515 llvm::MDNode::get(context, dereferenceable_bytes_metadata)); 516 } 517 518 llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper, 519 llvm::Instruction* inst) { 520 llvm::LLVMContext& context = inst->getParent()->getContext(); 521 llvm::IntegerType* i32 = llvm::Type::getInt32Ty(context); 522 inst->setMetadata( 523 llvm::LLVMContext::MD_range, 524 llvm::MDNode::get( 525 context, 526 {llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, lower)), 527 llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, upper))})); 528 return inst; 529 } 530 531 string IrName(string a) { 532 a.erase(std::remove(a.begin(), a.end(), '%'), a.end()); 533 return a; 534 } 535 536 string IrName(tensorflow::StringPiece a, tensorflow::StringPiece b) { 537 if (!a.empty() && !b.empty()) { 538 return IrName(tensorflow::strings::StrCat(a, ".", b)); 539 } 540 return IrName(tensorflow::strings::StrCat(a, b)); 541 } 542 543 string IrName(const HloInstruction* a, tensorflow::StringPiece b) { 544 return IrName(a->name(), b); 545 } 546 547 string SanitizeFunctionName(string function_name) { 548 // The backend with the strictest requirements on function names is NVPTX, so 549 // we sanitize to its requirements. 550 // 551 // A slightly stricter version of the NVPTX requirements is that names match 552 // /[a-zA-Z_$][a-zA-Z0-9_$]*/, with the exception that the names "_" and "$" 553 // are illegal. 554 555 // Sanitize chars in function_name. 556 std::transform(function_name.begin(), function_name.end(), 557 function_name.begin(), [](char c) { 558 if (('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || 559 ('0' <= c && c <= '9') || c == '_' || c == '$') { 560 return c; 561 } 562 return '_'; 563 }); 564 565 // Ensure the name isn't empty. 566 if (function_name.empty()) { 567 function_name = "__unnamed"; 568 } 569 570 // Ensure the name doesn't start with a number. 571 if (!function_name.empty() && function_name[0] >= '0' && 572 function_name[0] <= '9') { 573 function_name.insert(function_name.begin(), '_'); 574 } 575 576 // Ensure the name isn't "_" or "$". 577 if (function_name == "_" || function_name == "$") { 578 function_name += '_'; 579 } 580 581 return function_name; 582 } 583 584 void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) { 585 builder->SetInsertPoint(blk, blk->getFirstInsertionPt()); 586 } 587 588 void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) { 589 if (llvm::Instruction* terminator = blk->getTerminator()) { 590 builder->SetInsertPoint(terminator); 591 } else { 592 builder->SetInsertPoint(blk); 593 } 594 } 595 596 llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor, 597 llvm::IRBuilder<>* builder) { 598 auto size = rotand->getType()->getPrimitiveSizeInBits(); 599 auto size_value = builder->getIntN(size, size); 600 auto mod = [=](llvm::Value* x) { return builder->CreateURem(x, size_value); }; 601 return builder->CreateOr( 602 builder->CreateShl(rotand, mod(builder->CreateSub(size_value, rotor))), 603 builder->CreateLShr(rotand, mod(rotor))); 604 } 605 606 int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) { 607 unsigned pointer_size = data_layout.getPointerSize(); 608 return ShapeUtil::ByteSizeOf(shape, pointer_size); 609 } 610 611 llvm::FastMathFlags GetFastMathFlags(bool fast_math_enabled) { 612 llvm::FastMathFlags flags; 613 if (fast_math_enabled) { 614 // Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros, 615 // AllowReciprocal, AllowContract, and ApproxFunc. 616 flags.setFast(); 617 } 618 return flags; 619 } 620 621 void SetTargetOptions(bool fast_math_enabled, 622 llvm::TargetOptions* target_options) { 623 // In LLVM backend flags, UnsafeFPMath does not explicitly imply 624 // NoInfs, etc. 625 target_options->UnsafeFPMath = fast_math_enabled; 626 target_options->NoInfsFPMath = fast_math_enabled; 627 target_options->NoNaNsFPMath = fast_math_enabled; 628 target_options->NoSignedZerosFPMath = fast_math_enabled; 629 } 630 631 std::map<int, llvm::MDNode*> MergeMetadata( 632 llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a, 633 const std::map<int, llvm::MDNode*>& b) { 634 // We should extend this as needed to deal with other kinds of metadata like 635 // !dereferenceable and !range. 636 637 std::map<int, llvm::MDNode*> result; 638 for (auto kind_md_pair : a) { 639 if (kind_md_pair.first == llvm::LLVMContext::MD_alias_scope) { 640 llvm::SmallVector<llvm::Metadata*, 8> union_of_scopes; 641 llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set; 642 for (const auto& scope_a : kind_md_pair.second->operands()) { 643 scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get())); 644 union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_a.get())); 645 } 646 auto it = b.find(kind_md_pair.first); 647 if (it != b.end()) { 648 for (const auto& scope_b : it->second->operands()) { 649 if (!scope_set.count(llvm::cast<llvm::MDNode>(scope_b.get()))) { 650 union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b.get())); 651 } 652 } 653 } 654 result[llvm::LLVMContext::MD_alias_scope] = 655 llvm::MDNode::get(*context, union_of_scopes); 656 } else if (kind_md_pair.first == llvm::LLVMContext::MD_noalias) { 657 llvm::SmallVector<llvm::Metadata*, 8> intersection_of_scopes; 658 llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set; 659 for (const auto& scope_a : kind_md_pair.second->operands()) { 660 scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get())); 661 } 662 auto it = b.find(kind_md_pair.first); 663 if (it != b.end()) { 664 for (const auto& scope_b : it->second->operands()) { 665 if (scope_set.count(llvm::cast<llvm::MDNode>(scope_b))) { 666 intersection_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b)); 667 } 668 } 669 } 670 if (!intersection_of_scopes.empty()) { 671 result[llvm::LLVMContext::MD_noalias] = 672 llvm::MDNode::get(*context, intersection_of_scopes); 673 } 674 } 675 } 676 return result; 677 } 678 679 static string GetProcessUniqueIrFileName(tensorflow::StringPiece prefix) { 680 static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); 681 static NameUniquer* uniquer = new NameUniquer(/*separator=*/"-"); 682 683 tensorflow::mutex_lock lock(mu); 684 return uniquer->GetUniqueName(prefix); 685 } 686 687 static Status CreateAndWriteStringToFile(const string& directory_name, 688 const string& file_name, 689 const string& text) { 690 std::unique_ptr<tensorflow::WritableFile> f; 691 TF_RETURN_IF_ERROR( 692 tensorflow::Env::Default()->RecursivelyCreateDir(directory_name)); 693 TF_RETURN_IF_ERROR( 694 tensorflow::Env::Default()->NewWritableFile(file_name, &f)); 695 TF_RETURN_IF_ERROR(f->Append(text)); 696 TF_RETURN_IF_ERROR(f->Close()); 697 return Status::OK(); 698 } 699 700 Status DumpIRToDirectory(const string& directory_name, 701 const string& hlo_module_name, 702 const llvm::Module& llvm_module, bool optimized) { 703 // We can end up compiling different modules with the same name when using 704 // XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously 705 // dumped from the same process in such cases. 706 string unique_and_safe_file_name = GetProcessUniqueIrFileName( 707 tensorflow::strings::StrCat("ir-", SanitizeFileName(hlo_module_name), "-", 708 optimized ? "with" : "no", "-opt")); 709 710 string ir_file_name = tensorflow::io::JoinPath( 711 directory_name, 712 tensorflow::strings::StrCat(unique_and_safe_file_name, ".ll")); 713 714 // For some models the embedded constants can be huge, so also dump the module 715 // with the constants stripped to get IR that is easier to manipulate. 716 string ir_no_constant_initializers_file_name = tensorflow::io::JoinPath( 717 directory_name, 718 tensorflow::strings::StrCat(unique_and_safe_file_name, "-noconst.ll")); 719 720 TF_RETURN_IF_ERROR(CreateAndWriteStringToFile( 721 directory_name, ir_file_name, DumpModuleToString(llvm_module))); 722 return CreateAndWriteStringToFile( 723 directory_name, ir_no_constant_initializers_file_name, 724 DumpModuleToString(*DropConstantInitializers(llvm_module))); 725 } 726 727 llvm::Function* CreateFunction(llvm::FunctionType* function_type, 728 llvm::GlobalValue::LinkageTypes linkage, 729 bool enable_fast_math, bool optimize_for_size, 730 tensorflow::StringPiece name, 731 llvm::Module* module) { 732 llvm::Function* function = 733 llvm::Function::Create(function_type, linkage, AsStringRef(name), module); 734 function->setCallingConv(llvm::CallingConv::C); 735 function->addFnAttr("no-frame-pointer-elim", "false"); 736 737 if (enable_fast_math) { 738 function->addFnAttr("unsafe-fp-math", "true"); 739 function->addFnAttr("no-infs-fp-math", "true"); 740 function->addFnAttr("no-nans-fp-math", "true"); 741 function->addFnAttr("no-signed-zeros-fp-math", "true"); 742 } 743 744 // Add the optize attribute to the function if optimizing for size. This 745 // controls internal behavior of some optimization passes (e.g. loop 746 // unrolling). 747 if (optimize_for_size) { 748 function->addFnAttr(llvm::Attribute::OptimizeForSize); 749 } 750 751 return function; 752 } 753 754 void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { 755 auto options = config.debug_options().xla_backend_extra_options(); 756 if (!options.empty()) { 757 std::vector<string> fake_argv_storage; 758 fake_argv_storage.push_back(""); 759 for (const auto& it : options) { 760 // Skip options the XLA backend itself consumes. 761 if (!tensorflow::StringPiece(it.first).starts_with("xla_")) { 762 if (it.second.empty()) { 763 fake_argv_storage.push_back(it.first); 764 } else { 765 fake_argv_storage.push_back(it.first + "=" + it.second); 766 } 767 } 768 } 769 770 VLOG(2) << "Passing argv to LLVM:"; 771 std::vector<const char*> fake_argv; 772 for (const auto& s : fake_argv_storage) { 773 fake_argv.push_back(s.c_str()); 774 VLOG(2) << s; 775 } 776 llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]); 777 } 778 } 779 780 } // namespace llvm_ir 781 } // namespace xla 782