/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"

#include <algorithm>
#include <memory>
#include <vector>

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Operator.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace llvm_ir {

namespace {

// Note: this function is only useful in an insertion context; in a global
// (e.g. constants) context it will CHECK-fail.
llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* ir_builder) {
  auto block = CHECK_NOTNULL(ir_builder->GetInsertBlock());
  auto fn = CHECK_NOTNULL(block->getParent());
  auto module = CHECK_NOTNULL(fn->getParent());
  return module;
}

}  // namespace

string AsString(const std::string& str) {
  return string(str.data(), str.length());
}

llvm::StringRef AsStringRef(tensorflow::StringPiece str) {
  return llvm::StringRef(str.data(), str.size());
}

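// Returns a clone of `module` with every global variable's initializer
// dropped and its linkage reset to external, so the cloned IR still
// references the globals without carrying their (potentially huge) constant
// data.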
std::unique_ptr<llvm::Module> DropConstantInitializers(
    const llvm::Module& module) {
  std::unique_ptr<llvm::Module> cloned_module = CloneModule(module);
  for (llvm::GlobalVariable& global_var : cloned_module->globals()) {
    global_var.setInitializer(nullptr);
    global_var.setLinkage(llvm::GlobalValue::LinkageTypes::ExternalLinkage);
  }
  return cloned_module;
}

string DumpModuleToString(const llvm::Module& module) {
  std::string buffer_string;
  llvm::raw_string_ostream ostream(buffer_string);
  module.print(ostream, nullptr);
  ostream.flush();
  return AsString(buffer_string);
}

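// Gets or inserts the declaration of the (possibly overloaded) intrinsic
// `intrinsic_id` in the module owning the current insertion point, then emits
// a call to it with `operands`.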
llvm::Value* EmitCallToIntrinsic(
    llvm::Intrinsic::ID intrinsic_id,
    tensorflow::gtl::ArraySlice<llvm::Value*> operands,
    tensorflow::gtl::ArraySlice<llvm::Type*> overloaded_types,
    llvm::IRBuilder<>* ir_builder) {
  std::vector<llvm::Type*> types;
  for (auto type : overloaded_types) {
    types.push_back(type);
  }
  llvm::Module* module = ModuleFromIRBuilder(ir_builder);
  llvm::Function* intrinsic =
      llvm::Intrinsic::getDeclaration(module, intrinsic_id, types);
  std::vector<llvm::Value*> operands_vec;
  for (auto operand : operands) {
    operands_vec.push_back(operand);
  }
  return ir_builder->CreateCall(intrinsic, operands_vec);
}

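// Emits a floating-point max. When the no-NaNs fast-math flag is set a plain
// compare-and-select suffices; otherwise we call llvm.maxnum, which returns
// the non-NaN operand when exactly one input is NaN.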
llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
                          llvm::IRBuilder<>* ir_builder) {
  if (ir_builder->getFastMathFlags().noNaNs()) {
    auto cmp = ir_builder->CreateFCmpUGE(lhs_value, rhs_value);
    return ir_builder->CreateSelect(cmp, lhs_value, rhs_value);
  } else {
    return EmitCallToIntrinsic(llvm::Intrinsic::maxnum, {lhs_value, rhs_value},
                               {lhs_value->getType()}, ir_builder);
  }
}

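// Emits a floating-point min; the same no-NaNs reasoning as EmitFloatMax
// applies, with llvm.minnum as the NaN-aware fallback.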
llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
                          llvm::IRBuilder<>* ir_builder) {
  if (ir_builder->getFastMathFlags().noNaNs()) {
    auto cmp = ir_builder->CreateFCmpULE(lhs_value, rhs_value);
    return ir_builder->CreateSelect(cmp, lhs_value, rhs_value);
  } else {
    return EmitCallToIntrinsic(llvm::Intrinsic::minnum, {lhs_value, rhs_value},
                               {lhs_value->getType()}, ir_builder);
  }
}

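// Emits an in-bounds GEP of `array` at `index`. A global variable is a
// pointer to its underlying array type, so it needs a leading zero index to
// step through that pointer before indexing the array elements themselves.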
llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index,
                                   llvm::IRBuilder<>* ir_builder) {
  llvm::Type* array_type = array->getType();
  CHECK(array_type->isPointerTy());
  llvm::PointerType* array_type_as_pointer =
      llvm::cast<llvm::PointerType>(array_type);
  VLOG(2) << "EmitBufferIndexingGEP with type="
          << llvm_ir::DumpToString(*array_type)
          << " array=" << llvm_ir::DumpToString(*array)
          << " index=" << llvm_ir::DumpToString(*index);

  return ir_builder->CreateInBoundsGEP(
      array_type_as_pointer->getElementType(), array,
      llvm::isa<llvm::GlobalVariable>(array)
          ? llvm::ArrayRef<llvm::Value*>({ir_builder->getInt64(0), index})
          : index);
}

llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index,
                                   llvm::IRBuilder<>* ir_builder) {
  return EmitBufferIndexingGEP(array, ir_builder->getInt64(index), ir_builder);
}

llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
                                  llvm::Module* module) {
  switch (element_type) {
    case PRED:
    case S8:
    case U8:
      return llvm::Type::getInt8Ty(module->getContext());
    case S16:
    case U16:
    case BF16:
      // For BF16 we just need some type that is 16 bits wide so that it will
      // take up the right amount of space in memory. LLVM does not have a BF16
      // type (the LLVM half type is IEEE 16 bit floating point, not bfloat), so
      // we can't map it directly to an LLVM type. We will not map a BF16
      // addition to an addition on this type (int16) - this is just the type
      // used for storage.
      return llvm::Type::getInt16Ty(module->getContext());
    case F16:
      return llvm::Type::getHalfTy(module->getContext());
    case S32:
    case U32:
      return llvm::Type::getInt32Ty(module->getContext());
    case S64:
    case U64:
      return llvm::Type::getInt64Ty(module->getContext());
    case F32:
      return llvm::Type::getFloatTy(module->getContext());
    case F64:
      return llvm::Type::getDoubleTy(module->getContext());
    case C64: {
      auto cplx_t = module->getTypeByName("complex64");
      if (cplx_t == nullptr) {
        // C++ standard dictates the memory layout of std::complex is contiguous
        // real followed by imaginary. C++11 section 26.4 [complex.numbers]:
        // If z is an lvalue expression of type cv std::complex<T> then the
        // expression reinterpret_cast<cv T(&)[2]>(z) shall be well-formed,
        // reinterpret_cast<cv T(&)[2]>(z)[0] shall designate the real part of
        // z, and reinterpret_cast<cv T(&)[2]>(z)[1] shall designate the
        // imaginary part of z.
        return llvm::StructType::create(
            {llvm::Type::getFloatTy(module->getContext()),
             llvm::Type::getFloatTy(module->getContext())},
            "complex64", /*isPacked=*/true);
      }
      return cplx_t;
    }
    // A Tuple contains an array of pointers. Use i8*.
    case TUPLE:
    // An Opaque is like a void*; use i8*.
    case OPAQUE:
      return llvm::Type::getInt8PtrTy(module->getContext());
    default:
      LOG(FATAL) << "unsupported type " << element_type;
  }
}

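// Returns the size of `type` in bits. Packed structs are summed element by
// element; any other type without a primitive size trips the CHECK below.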
int GetSizeInBits(llvm::Type* type) {
  const llvm::StructType* struct_ty = llvm::dyn_cast<llvm::StructType>(type);
  if (struct_ty) {
    CHECK(struct_ty->isPacked());
    int bits = 0;
    for (auto element_type : struct_ty->elements()) {
      bits += GetSizeInBits(element_type);
    }
    return bits;
  }
  int bits = type->getPrimitiveSizeInBits();
  CHECK_GT(bits, 0) << "type is not sized";
  return bits;
}

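// Tuples lower to an array of i8* (one slot per element). Array shapes lower
// to nested LLVM arrays built minor-to-major, so the innermost LLVM dimension
// corresponds to the most minor XLA dimension.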
llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) {
  llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module);
  if (ShapeUtil::IsTuple(shape)) {
    // A tuple buffer is an array of pointers.
    result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size());
  } else if (ShapeUtil::IsArray(shape)) {
    for (int64 dimension : LayoutUtil::MinorToMajor(shape)) {
      result_type =
          llvm::ArrayType::get(result_type, shape.dimensions(dimension));
    }
  }
  return result_type;
}

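// Serializes `shape` into a global string constant and returns a pointer to
// it, reporting the byte count through `*shape_size` so that
// DecodeSelfDescribingShapeConstant below can parse the proto back out at
// runtime.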
StatusOr<llvm::Value*> EncodeSelfDescribingShapeConstant(
    const Shape& shape, int32* shape_size, llvm::IRBuilder<>* ir_builder) {
  string encoded_shape = shape.SerializeAsString();
  if (encoded_shape.size() > std::numeric_limits<int32>::max()) {
    return InternalError("Encoded shape size exceeded int32 size limit.");
  }
  *shape_size = static_cast<int32>(encoded_shape.size());
  return ir_builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(encoded_shape));
}

StatusOr<Shape> DecodeSelfDescribingShapeConstant(const void* shape_ptr,
                                                  int32 size_bytes) {
  Shape shape;
  TF_RET_CHECK(shape.ParseFromArray(shape_ptr, size_bytes));
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
  return shape;
}

namespace {

// Recursively construct a multidimensional LLVM constant which represents the
// given literal. The minor-to-major dimension ordering in the constant matches
// that of the literal. For example, given a [2 x 3 x 4] Literal (dimension 0
// has size 4, dimension 1 has size 3, etc) of primitive type F32 with a
// minor_to_major value of [2, 1, 0] (column major), an LLVM constant of type
// [4 x [3 x [2 x float]]] will be returned.
//
// multi_index is a multidimensional index into the array. dimension_index is an
// index into the minor_to_major field in the literal shape. This determines
// which dimension is iterated over in this level of the recursion. Dimensions
// are iterated from most major down to most minor (highest dimension_index
// value down to zero).
llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
                                  std::vector<int64>* multi_index,
                                  llvm::Module* module) {
  const Shape& shape = literal.shape();
  llvm::Type* ir_element_type =
      llvm_ir::PrimitiveTypeToIrType(shape.element_type(), module);
  if (dimension_index == -1) {
    // Base case of the recursion. Index into the data field of the protobuf
    // with the multi index.
    llvm::Constant* value;
    switch (shape.element_type()) {
      case PRED:
        value = llvm::ConstantInt::get(ir_element_type,
                                       literal.Get<bool>(*multi_index));
        break;
      case U8:
        value = llvm::ConstantInt::get(ir_element_type,
                                       literal.Get<uint8>(*multi_index));
        break;
      case S32:
        value = llvm::ConstantInt::get(ir_element_type,
                                       literal.Get<int32>(*multi_index));
        break;
      case U32:
        value = llvm::ConstantInt::get(ir_element_type,
                                       literal.Get<uint32>(*multi_index));
        break;
      case S64:
        value = llvm::ConstantInt::get(ir_element_type,
                                       literal.Get<int64>(*multi_index));
        break;
      case U64:
        value = llvm::ConstantInt::get(ir_element_type,
                                       literal.Get<uint64>(*multi_index));
        break;
      case F32:
        value = llvm::ConstantFP::get(ir_element_type,
                                      literal.Get<float>(*multi_index));
        break;
      case BF16:
        value = llvm::ConstantInt::get(
            ir_element_type,
            tensorflow::bit_cast<uint16>(literal.Get<bfloat16>(*multi_index)));
        break;
      case F16:
        value = llvm::ConstantFP::get(
            ir_element_type,
            static_cast<float>(literal.Get<half>(*multi_index)));
        break;
      case F64:
        value = llvm::ConstantFP::get(ir_element_type,
                                      literal.Get<double>(*multi_index));
        break;
      case C64: {
        complex64 x = literal.Get<complex64>(*multi_index);
        value = llvm::ConstantStruct::get(
            static_cast<llvm::StructType*>(ir_element_type),
            llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module),
                                  x.real()),
            llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module),
                                  x.imag()));
        break;
      }
      default:
        LOG(FATAL) << "unsupported type " << shape.element_type();
    }
    return value;
  }

  // The dimension index starts at one less than the rank of the array and
  // decrements with each recursive call. We want to iterate through the
  // dimensions in major-to-minor order as we recurse, so we just index into
  // minor_to_major to get the dimension number for this level of the recursion.
  int64 dimension = LayoutUtil::Minor(shape.layout(), dimension_index);

  // Recursively call LiteralToConstant to construct subarrays for the
  // more-minor dimensions. Gather the subarrays into a vector for bundling into
  // a new (higher-dimensional) ConstantArray.
  std::vector<llvm::Constant*> elements;
  for (int64 i = 0; i < shape.dimensions(dimension); ++i) {
    (*multi_index)[dimension] = i;
    elements.push_back(
        LiteralToConstant(literal, dimension_index - 1, multi_index, module));
  }

  llvm::Type* element_type;
  if (elements.empty()) {
    element_type = ir_element_type;
    for (int i = 0; i < dimension_index; ++i) {
      int64 index = LayoutUtil::Minor(shape.layout(), i);
      element_type =
          llvm::ArrayType::get(element_type, shape.dimensions(index));
    }
  } else {
    element_type = elements[0]->getType();
  }
  llvm::ArrayType* aggregate_type =
      llvm::ArrayType::get(element_type, shape.dimensions(dimension));
  return llvm::ConstantArray::get(aggregate_type, elements);
}

}  // namespace

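// Entry point for the recursion above: start at the most major dimension with
// a zeroed multi-index spanning the literal's full rank.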
llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
                                           llvm::Module* module) {
  std::vector<int64> multi_index(ShapeUtil::Rank(literal.shape()), 0);
  llvm::Constant* value = LiteralToConstant(
      literal, /*dimension_index=*/ShapeUtil::Rank(literal.shape()) - 1,
      &multi_index, module);
  return value;
}

llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
                                            tensorflow::StringPiece name,
                                            llvm::IRBuilder<>* ir_builder,
                                            int alignment) {
  return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, ir_builder,
                                            alignment);
}

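// Temporarily moves the insert point to the start of the function's entry
// block, emits the alloca there, and restores the caller's insert point.
// Entry-block allocas are "static": they are allocated once per call and are
// the only allocas mem2reg will promote to registers.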
llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(
    llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name,
    llvm::IRBuilder<>* ir_builder, int alignment) {
  llvm::IRBuilder<>::InsertPoint insert_point = ir_builder->saveIP();
  llvm::Function* function = ir_builder->GetInsertBlock()->getParent();
  ir_builder->SetInsertPoint(&function->getEntryBlock(),
                             function->getEntryBlock().getFirstInsertionPt());
  llvm::AllocaInst* alloca =
      ir_builder->CreateAlloca(type, element_count, AsStringRef(name));
  if (alignment != 0) {
    alloca->setAlignment(alignment);
  }
  ir_builder->restoreIP(insert_point);
  return alloca;
}

llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
                                   tensorflow::StringPiece name,
                                   llvm::IRBuilder<>* ir_builder) {
  return llvm::BasicBlock::Create(
      /*Context=*/ir_builder->getContext(),
      /*Name=*/AsStringRef(name),
      /*Parent=*/ir_builder->GetInsertBlock()->getParent(),
      /*InsertBefore=*/insert_before);
}

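// Carves the current block into if/true/false/after blocks. The current block
// either gets a fresh after-block appended (when it has no terminator yet) or
// is split at the insert point; its terminator is then replaced with a
// conditional branch to the true block (and the false block, if requested).
// On return the builder points at the start of the after-block.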
LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name,
                          llvm::IRBuilder<>* ir_builder, bool emit_else) {
  llvm_ir::LlvmIfData if_data;
  if_data.if_block = ir_builder->GetInsertBlock();
  if_data.true_block = CreateBasicBlock(
      nullptr, tensorflow::strings::StrCat(name, "-true"), ir_builder);
  if_data.false_block =
      emit_else ? CreateBasicBlock(nullptr,
                                   tensorflow::strings::StrCat(name, "-false"),
                                   ir_builder)
                : nullptr;

  // Add a terminator to the if block, if necessary.
  if (if_data.if_block->getTerminator() == nullptr) {
    ir_builder->SetInsertPoint(if_data.if_block);
    if_data.after_block = CreateBasicBlock(
        nullptr, tensorflow::strings::StrCat(name, "-after"), ir_builder);
    ir_builder->CreateBr(if_data.after_block);
  } else {
    if_data.after_block = if_data.if_block->splitBasicBlock(
        ir_builder->GetInsertPoint(),
        AsStringRef(tensorflow::strings::StrCat(name, "-after")));
  }

  // Our basic block should now end with an unconditional branch.  Remove it;
  // we're going to replace it with a conditional branch.
  if_data.if_block->getTerminator()->eraseFromParent();

  ir_builder->SetInsertPoint(if_data.if_block);
  ir_builder->CreateCondBr(
      condition, if_data.true_block,
      emit_else ? if_data.false_block : if_data.after_block);

  ir_builder->SetInsertPoint(if_data.true_block);
  ir_builder->CreateBr(if_data.after_block);

  if (emit_else) {
    ir_builder->SetInsertPoint(if_data.false_block);
    ir_builder->CreateBr(if_data.after_block);
  }

  ir_builder->SetInsertPoint(if_data.after_block,
                             if_data.after_block->getFirstInsertionPt());

  return if_data;
}

llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
                            llvm::Value* lhs_value, llvm::Value* rhs_value,
                            llvm::IRBuilder<>* ir_builder) {
  llvm::Value* comparison_result;
  if (lhs_value->getType()->isIntegerTy()) {
    comparison_result = ir_builder->CreateICmp(predicate, lhs_value, rhs_value);
  } else {
    comparison_result = ir_builder->CreateFCmp(predicate, lhs_value, rhs_value);
  }
  // comparison_result is i1, but the NVPTX codegen incorrectly lowers i1
  // arrays. So we extend it to i8 so that it's addressable.
  return ir_builder->CreateZExt(
      comparison_result,
      llvm_ir::PrimitiveTypeToIrType(PRED, ModuleFromIRBuilder(ir_builder)));
}

// Internal helper that is called from emitted code to log an int64 value with a
// tag.
static void LogS64(const char* tag, int64 value) {
  LOG(INFO) << tag << " (int64): " << value;
}

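// Emits a call to LogS64 above by baking the function's address into the IR
// as an i64 constant and casting it back to a function pointer. This only
// works when the emitted code runs in the process that generated it (e.g.
// under the CPU JIT), since the address is process-specific.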
void EmitLogging(const char* tag, llvm::Value* value,
                 llvm::IRBuilder<>* ir_builder) {
  llvm::FunctionType* log_function_type = llvm::FunctionType::get(
      ir_builder->getVoidTy(),
      {ir_builder->getInt64Ty(), ir_builder->getInt64Ty()}, /*isVarArg=*/false);
  ir_builder->CreateCall(
      log_function_type,
      ir_builder->CreateIntToPtr(
          ir_builder->getInt64(tensorflow::bit_cast<int64>(&LogS64)),
          log_function_type->getPointerTo()),
      {ir_builder->getInt64(tensorflow::bit_cast<int64>(tag)), value});
}

void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) {
  llvm::LLVMContext& context = load->getContext();
  llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
  llvm::Constant* alignment_constant =
      llvm::ConstantInt::get(int64_ty, alignment);
  llvm::MDBuilder metadata_builder(context);
  auto* alignment_metadata =
      metadata_builder.createConstant(alignment_constant);
  load->setMetadata(llvm::LLVMContext::MD_align,
                    llvm::MDNode::get(context, alignment_metadata));
}

void SetDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                       uint64_t dereferenceable_bytes) {
  llvm::LLVMContext& context = load->getContext();
  llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
  llvm::Constant* dereferenceable_bytes_constant =
      llvm::ConstantInt::get(int64_ty, dereferenceable_bytes);
  llvm::MDBuilder metadata_builder(context);
  auto* dereferenceable_bytes_metadata =
      metadata_builder.createConstant(dereferenceable_bytes_constant);
  load->setMetadata(llvm::LLVMContext::MD_dereferenceable,
                    llvm::MDNode::get(context, dereferenceable_bytes_metadata));
}

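// Attaches !range metadata declaring that `inst` produces a value in
// [lower, upper). The bounds are emitted as i32 constants, and !range
// requires the constant type to match the instruction's result type, so this
// helper is only suitable for i32-producing instructions.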
llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper,
                                    llvm::Instruction* inst) {
  llvm::LLVMContext& context = inst->getParent()->getContext();
  llvm::IntegerType* i32 = llvm::Type::getInt32Ty(context);
  inst->setMetadata(
      llvm::LLVMContext::MD_range,
      llvm::MDNode::get(
          context,
          {llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, lower)),
           llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, upper))}));
  return inst;
}

string IrName(string a) {
  a.erase(std::remove(a.begin(), a.end(), '%'), a.end());
  return a;
}

string IrName(tensorflow::StringPiece a, tensorflow::StringPiece b) {
  if (!a.empty() && !b.empty()) {
    return IrName(tensorflow::strings::StrCat(a, ".", b));
  }
  return IrName(tensorflow::strings::StrCat(a, b));
}

string IrName(const HloInstruction* a, tensorflow::StringPiece b) {
  return IrName(a->name(), b);
}

string SanitizeFunctionName(string function_name) {
  // The backend with the strictest requirements on function names is NVPTX, so
  // we sanitize to its requirements.
  //
  // A slightly stricter version of the NVPTX requirements is that names match
  // /[a-zA-Z_$][a-zA-Z0-9_$]*/, with the exception that the names "_" and "$"
  // are illegal.

  // Sanitize chars in function_name.
  std::transform(function_name.begin(), function_name.end(),
                 function_name.begin(), [](char c) {
                   if (('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') ||
                       ('0' <= c && c <= '9') || c == '_' || c == '$') {
                     return c;
                   }
                   return '_';
                 });

  // Ensure the name isn't empty.
  if (function_name.empty()) {
    function_name = "__unnamed";
  }

  // Ensure the name doesn't start with a number.
  if (!function_name.empty() && function_name[0] >= '0' &&
      function_name[0] <= '9') {
    function_name.insert(function_name.begin(), '_');
  }

  // Ensure the name isn't "_" or "$".
  if (function_name == "_" || function_name == "$") {
    function_name += '_';
  }

  return function_name;
}

void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
  builder->SetInsertPoint(blk, blk->getFirstInsertionPt());
}

void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
  if (llvm::Instruction* terminator = blk->getTerminator()) {
    builder->SetInsertPoint(terminator);
  } else {
    builder->SetInsertPoint(blk);
  }
}

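// Emits a rotate-right: (rotand >> (rotor mod n)) | (rotand << ((n - rotor)
// mod n)), where n is the bit width. Reducing both shift amounts mod n keeps
// them in range, including the rotor == 0 case, where the left shift would
// otherwise be by the full width n.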
llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
                       llvm::IRBuilder<>* builder) {
  auto size = rotand->getType()->getPrimitiveSizeInBits();
  auto size_value = builder->getIntN(size, size);
  auto mod = [=](llvm::Value* x) { return builder->CreateURem(x, size_value); };
  return builder->CreateOr(
      builder->CreateShl(rotand, mod(builder->CreateSub(size_value, rotor))),
      builder->CreateLShr(rotand, mod(rotor)));
}

int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) {
  unsigned pointer_size = data_layout.getPointerSize();
  return ShapeUtil::ByteSizeOf(shape, pointer_size);
}

llvm::FastMathFlags GetFastMathFlags(bool fast_math_enabled) {
  llvm::FastMathFlags flags;
  if (fast_math_enabled) {
    // Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros,
    // AllowReciprocal, AllowContract, and ApproxFunc.
    flags.setFast();
  }
  return flags;
}

void SetTargetOptions(bool fast_math_enabled,
                      llvm::TargetOptions* target_options) {
  // In LLVM backend flags, UnsafeFPMath does not explicitly imply
  // NoInfs, etc.
  target_options->UnsafeFPMath = fast_math_enabled;
  target_options->NoInfsFPMath = fast_math_enabled;
  target_options->NoNaNsFPMath = fast_math_enabled;
  target_options->NoSignedZerosFPMath = fast_math_enabled;
}

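// Merges two instructions' metadata maps conservatively: !alias.scope lists
// are unioned (the merged instruction may belong to either scope set), while
// !noalias lists are intersected (a no-alias claim may only be kept if both
// inputs agree on it).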
std::map<int, llvm::MDNode*> MergeMetadata(
    llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
    const std::map<int, llvm::MDNode*>& b) {
  // We should extend this as needed to deal with other kinds of metadata like
  // !dereferenceable and !range.

  std::map<int, llvm::MDNode*> result;
  for (auto kind_md_pair : a) {
    if (kind_md_pair.first == llvm::LLVMContext::MD_alias_scope) {
      llvm::SmallVector<llvm::Metadata*, 8> union_of_scopes;
      llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
      for (const auto& scope_a : kind_md_pair.second->operands()) {
        scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
        union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_a.get()));
      }
      auto it = b.find(kind_md_pair.first);
      if (it != b.end()) {
        for (const auto& scope_b : it->second->operands()) {
          if (!scope_set.count(llvm::cast<llvm::MDNode>(scope_b.get()))) {
            union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b.get()));
          }
        }
      }
      result[llvm::LLVMContext::MD_alias_scope] =
          llvm::MDNode::get(*context, union_of_scopes);
    } else if (kind_md_pair.first == llvm::LLVMContext::MD_noalias) {
      llvm::SmallVector<llvm::Metadata*, 8> intersection_of_scopes;
      llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
      for (const auto& scope_a : kind_md_pair.second->operands()) {
        scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
      }
      auto it = b.find(kind_md_pair.first);
      if (it != b.end()) {
        for (const auto& scope_b : it->second->operands()) {
          if (scope_set.count(llvm::cast<llvm::MDNode>(scope_b))) {
            intersection_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b));
          }
        }
      }
      if (!intersection_of_scopes.empty()) {
        result[llvm::LLVMContext::MD_noalias] =
            llvm::MDNode::get(*context, intersection_of_scopes);
      }
    }
  }
  return result;
}

static string GetProcessUniqueIrFileName(tensorflow::StringPiece prefix) {
  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
  static NameUniquer* uniquer = new NameUniquer(/*separator=*/"-");

  tensorflow::mutex_lock lock(mu);
  return uniquer->GetUniqueName(prefix);
}

static Status CreateAndWriteStringToFile(const string& directory_name,
                                         const string& file_name,
                                         const string& text) {
  std::unique_ptr<tensorflow::WritableFile> f;
  TF_RETURN_IF_ERROR(
      tensorflow::Env::Default()->RecursivelyCreateDir(directory_name));
  TF_RETURN_IF_ERROR(
      tensorflow::Env::Default()->NewWritableFile(file_name, &f));
  TF_RETURN_IF_ERROR(f->Append(text));
  TF_RETURN_IF_ERROR(f->Close());
  return Status::OK();
}

Status DumpIRToDirectory(const string& directory_name,
                         const string& hlo_module_name,
                         const llvm::Module& llvm_module, bool optimized) {
  // We can end up compiling different modules with the same name when using
  // XlaJitCompiledCpuFunction::Compile.  Avoid overwriting IR files previously
  // dumped from the same process in such cases.
  string unique_and_safe_file_name = GetProcessUniqueIrFileName(
      tensorflow::strings::StrCat("ir-", SanitizeFileName(hlo_module_name), "-",
                                  optimized ? "with" : "no", "-opt"));

  string ir_file_name = tensorflow::io::JoinPath(
      directory_name,
      tensorflow::strings::StrCat(unique_and_safe_file_name, ".ll"));

  // For some models the embedded constants can be huge, so also dump the module
  // with the constants stripped to get IR that is easier to manipulate.
  string ir_no_constant_initializers_file_name = tensorflow::io::JoinPath(
      directory_name,
      tensorflow::strings::StrCat(unique_and_safe_file_name, "-noconst.ll"));

  TF_RETURN_IF_ERROR(CreateAndWriteStringToFile(
      directory_name, ir_file_name, DumpModuleToString(llvm_module)));
  return CreateAndWriteStringToFile(
      directory_name, ir_no_constant_initializers_file_name,
      DumpModuleToString(*DropConstantInitializers(llvm_module)));
}

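// Creates a function with the C calling convention and translates the
// fast-math and size-optimization settings into the corresponding function
// attributes, which per-function codegen and passes such as loop unrolling
// consult.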
llvm::Function* CreateFunction(llvm::FunctionType* function_type,
                               llvm::GlobalValue::LinkageTypes linkage,
                               bool enable_fast_math, bool optimize_for_size,
                               tensorflow::StringPiece name,
                               llvm::Module* module) {
  llvm::Function* function =
      llvm::Function::Create(function_type, linkage, AsStringRef(name), module);
  function->setCallingConv(llvm::CallingConv::C);
  function->addFnAttr("no-frame-pointer-elim", "false");

  if (enable_fast_math) {
    function->addFnAttr("unsafe-fp-math", "true");
    function->addFnAttr("no-infs-fp-math", "true");
    function->addFnAttr("no-nans-fp-math", "true");
    function->addFnAttr("no-signed-zeros-fp-math", "true");
  }

  // Add the OptimizeForSize attribute to the function if optimizing for size.
  // This controls internal behavior of some optimization passes (e.g. loop
  // unrolling).
  if (optimize_for_size) {
    function->addFnAttr(llvm::Attribute::OptimizeForSize);
  }

  return function;
}

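// Forwards every entry of xla_backend_extra_options that is not consumed by
// the XLA backend itself (i.e. not prefixed with "xla_") to LLVM's
// command-line parser as a synthetic argv. LLVM expects argv[0] to be a
// program name, hence the leading empty string.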
void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) {
  auto options = config.debug_options().xla_backend_extra_options();
  if (!options.empty()) {
    std::vector<string> fake_argv_storage;
    fake_argv_storage.push_back("");
    for (const auto& it : options) {
      // Skip options the XLA backend itself consumes.
      if (!tensorflow::StringPiece(it.first).starts_with("xla_")) {
        if (it.second.empty()) {
          fake_argv_storage.push_back(it.first);
        } else {
          fake_argv_storage.push_back(it.first + "=" + it.second);
        }
      }
    }

    VLOG(2) << "Passing argv to LLVM:";
    std::vector<const char*> fake_argv;
    for (const auto& s : fake_argv_storage) {
      fake_argv.push_back(s.c_str());
      VLOG(2) << s;
    }
    llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]);
  }
}

}  // namespace llvm_ir
}  // namespace xla