Home | History | Annotate | Download | only in llvm_ir
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_UTIL_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_UTIL_H_
     18 
#include <stdint.h>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
     39 
     40 namespace llvm {
     41 class FastMathFlags;
     42 class TargetOptions;
     43 };
     44 
     45 namespace xla {
     46 namespace llvm_ir {
     47 
     48 // Convert a absl::string_view to a llvm::StringRef. Note: both
     49 // absl::string_view and llvm::StringRef are non-owning pointers into a
     50 // string in memory. This method is used to feed strings to LLVM
     51 // & Clang APIs that expect llvm::StringRef.
     52 inline llvm::StringRef AsStringRef(absl::string_view str) {
     53   return llvm::StringRef(str.data(), str.size());
     54 }
     55 
     56 template <typename T>
     57 llvm::ArrayRef<T> AsArrayRef(const std::vector<T>& vec) {
     58   return llvm::ArrayRef<T>(vec.data(), vec.size());
     59 }
     60 
     61 template <typename T>
     62 llvm::ArrayRef<T> AsArrayRef(const absl::Span<const T>& slice) {
     63   return llvm::ArrayRef<T>(slice.data(), slice.size());
     64 }
     65 
     66 // Dump the given LLVM entity to a string. This works for Types and Values.
     67 template <typename T>
     68 string DumpToString(const T& entity) {
     69   std::string buffer_string;
     70   llvm::raw_string_ostream ostream(buffer_string);
     71   entity.print(ostream);
     72   ostream.flush();
     73   return buffer_string;
     74 }
     75 
     76 // Dump the given LLVM module to a string. This requires a function distinct
     77 // from DumpToString because the signatures of the print() methods for Values
     78 // and Modules are slightly different.
     79 string DumpModuleToString(const llvm::Module& module);
     80 
     81 // Constructs a human-friendly name from the given inputs.  The result is
     82 // suitable for use as an llvm::Value's name.
     83 //
     84 // This is equivalent to
     85 //
     86 //   - changing the HloInstruction* to its name() (if we called that overload),
     87 //   - joining all of the nonempty inputs by '.', and then
     88 //   - removing all '%'s.
     89 //
     90 string IrName(string a);
     91 string IrName(absl::string_view a, absl::string_view b);
     92 string IrName(const HloInstruction* a, absl::string_view b = "");
     93 
     94 // Removes special characters from a function name.
     95 //
     96 // Note that this can cause different inputs to map to the same output, so after
     97 // sanitizing a function name, you must run it through a uniquer.
     98 string SanitizeFunctionName(string function_name);
     99 
    100 // Emits a call to the specified intrinsic with the given operands. Overloaded
    101 // intrinsics (for example, "minnum") must include a type in overloaded_types
    102 // for each overloaded type. Typically, overloaded intrinsics have only a single
    103 // overloaded type.
    104 llvm::CallInst* EmitCallToIntrinsic(
    105     llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
    106     absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b);
    107 
// Emits a float max: the maxnum intrinsic if fast math is disabled, or an
// fcmp+select sequence otherwise.
    110 llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
    111                           llvm::IRBuilder<>* b);
    112 
// Emits a float min: the minnum intrinsic if fast math is disabled, or an
// fcmp+select sequence otherwise.
    115 llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
    116                           llvm::IRBuilder<>* b);
    117 
// Convenience methods for emitting a GEP instruction that indexes into a buffer
// (1-dimensional array), equivalent to array[index]. The type is automatically
// determined from the element type of the array. The int64 index overload
// wraps the index in an i64 llvm::Value.
    122 llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index,
    123                                    llvm::IRBuilder<>* b);
    124 llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index,
    125                                    llvm::IRBuilder<>* b);
    126 
    127 // Returns the LLVM type which represents the given XLA primitive type.
    128 llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
    129                                   llvm::Module* module);
    130 
    131 // Returns the type size in bits. If "type" is a struct, it must be packed.
    132 int GetSizeInBits(llvm::Type* type);
    133 
    134 // Returns the LLVM type which represents the given XLA shape. For example,
    135 // if "shape" is [5 x [10 x f32]], the function returns [5 x [10 x float]].
    136 llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module);
    137 
    138 // Returns a value that represents a pointer to a global string constant that
    139 // encodes the shape as a serialized protobuf.
    140 StatusOr<llvm::Value*> EncodeSelfDescribingShapeConstant(const Shape& shape,
    141                                                          int32* shape_size,
    142                                                          llvm::IRBuilder<>* b);
    143 
// Inverts the encoding of a Shape protobuf into an LLVM global variable.
//
// This is intended to be called from the runtime to decode the llvm::Constants
// that are created via EncodeSelfDescribingShapeConstant and subsequently
// embedded into the program.
    149 StatusOr<Shape> DecodeSelfDescribingShapeConstant(const void* shape_ptr,
    150                                                   int32 size_bytes);
    151 
    152 // Converts a given literal to an IR Constant. Literals have known constant
    153 // values at IR emission time.
    154 llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
    155                                            llvm::Module* module);
    156 
    157 // Allocates a tile of shared memory.
    158 llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module,
    159                                                llvm::Type* tile_type,
    160                                                absl::string_view name);
    161 
// Inserts an alloca of the requested type at the entry point of the
// function that the builder is currently building. The builder's insertion
// point is the same after calling this function as it was before.
//
// This can be useful to avoid e.g. executing an alloca every time
// through a loop.
    169 llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
    170                                             absl::string_view name,
    171                                             llvm::IRBuilder<>* b,
    172                                             int alignment = 0);
    173 
    174 // As EmitAllocaAtFunctionEntry, but allocates element_count entries
    175 // instead of a single element.
    176 llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type,
    177                                                      llvm::Value* element_count,
    178                                                      absl::string_view name,
    179                                                      llvm::IRBuilder<>* b,
    180                                                      int alignment = 0);
    181 
    182 // Creates a basic block with the same context and function as for the
    183 // builder. Inserts at the end of the function if insert_before is
    184 // null.
    185 llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
    186                                    absl::string_view name,
    187                                    llvm::IRBuilder<>* b);
    188 
    189 // Struct with data on a conditional branch in a diamond shape created
    190 // via EmitIfThenElse.
    191 struct LlvmIfData {
    192   // The block that has the conditional branch.
    193   llvm::BasicBlock* if_block;
    194 
    195   // The block that is executed if the condition is true.
    196   llvm::BasicBlock* true_block;
    197 
    198   // The block that is executed if the condition is false.
    199   llvm::BasicBlock* false_block;
    200 
    201   // The block that follows after both the true_block and the
    202   // false_block.
    203   llvm::BasicBlock* after_block;
    204 };
    205 
    206 // Inserts a diamond-shaped if-then-else construct at the current
    207 // insertion point of the builder. This involves splitting the current
    208 // block into two blocks, at the insertion point, and introducing a
    209 // true-block and a false-block that connect the two split pieces. The
    210 // true-block is executed if the condition parameter evaluates to true
    211 // and otherwise the false-block is executed. If `emit_else` is false,
    212 // it jumps to the after-block rather than the false-block if the
    213 // condition is false, and the returned `false_block` is null.
    214 //
    215 // Currently the insertion point of the builder must be a well-formed
    216 // block with a terminator. If you need to use this for a
    217 // non-terminated block, just make the function able to do that too.
    218 LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
    219                           llvm::IRBuilder<>* b, bool emit_else = true);
    220 
    221 // Emits a compare operation between "lhs" and "rhs" with the given predicate,
    222 // and then converts the result to i8 so that it is addressable.
    223 llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
    224                             llvm::Value* lhs, llvm::Value* rhs,
    225                             llvm::IRBuilder<>* b);
    226 
    227 // Emits a call that logs the given value with the given tag as a prefix.
    228 // The provided tag and value are passed to a runtime logging call that is
    229 // embedded in this translation unit when the emitted code is executed.
    230 //
    231 // This can be very useful for debugging generated programs in short order when
    232 // developing new generated routines.
    233 //
    234 // Precondition: value must be an int64.
// Precondition: tag must be a stable pointer for the lifetime of the generated
// program (the constant pointer is burned into the program).
    237 void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* b);
    238 
    239 // Adds alignment metadata to a load instruction using the given alignment.
    240 // The alignment refers to the result of the load, not the load itself.
    241 void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment);
    242 
// Adds dereferenceable metadata to a load instruction using the given
// number of dereferenceable bytes.
// Dereferenceable refers to the result of the load, not the load itself.
    246 void SetDereferenceableMetadataForLoad(llvm::LoadInst* load,
    247                                        uint64_t dereferenceable_bytes);
    248 
    249 // Tells LLVM `inst >= lower && inst < upper`. Returns `inst` for convenience.
    250 llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper,
    251                                     llvm::Instruction* inst);
    252 
    253 void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder);
    254 
    255 void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder);
    256 
    257 // Create a bitwise rotation of `rotand` by `rotor`.
    258 llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
    259                        llvm::IRBuilder<>* builder);
    260 
    261 // Returns the number of bytes within the shape.
    262 int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout);
    263 
    264 // Gets an llvm::FastMathFlags that reflects the settings in the given
    265 // module config.
    266 llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config);
    267 
    268 // Computes a conservative union of the metadata in "a" and "b".  For
    269 // aliasing-related metadata, this means the result can be applied to
    270 // instructions whose aliasing relationship can be described either by "a" *or*
    271 // by "b".
    272 std::map<int, llvm::MDNode*> MergeMetadata(
    273     llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
    274     const std::map<int, llvm::MDNode*>& b);
    275 
// Dumps out `llvm_module` to the path specified in DebugOptions, if dumping is
// enabled for the given HLO module.
//
// A sanitized version of the name of `hlo_module` is incorporated into the file
// name. If `optimized` is true then a suffix of "-with-opt.ll" is used, else a
// suffix of "-no-opt.ll" is used.
    282 void DumpIrIfEnabled(const HloModule& hlo_module,
    283                      const llvm::Module& llvm_module, bool optimized);
    284 
    285 llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
    286                                   llvm::GlobalValue::LinkageTypes linkage,
    287                                   const HloModuleConfig& module_config,
    288                                   absl::string_view name, llvm::Module* module);
    289 
    290 // Extracts the xla_backend_extra_options from `config` and passes those that
    291 // don't start with xla_ to LLVM.
    292 void InitializeLLVMCommandLineOptions(const HloModuleConfig& config);
    293 
    294 // Zero-extends two 32-bit values to 64 bits, multiplies them, and returns the
    295 // result as a pair of (low 32 bits, high 32 bits).
    296 std::pair<llvm::Value*, llvm::Value*> UMulLowHigh32(llvm::IRBuilder<>* b,
    297                                                     llvm::Value* src0,
    298                                                     llvm::Value* src1);
    299 // Splits the 64-bit integer value into its high and low 32 bits.
    300 std::pair<llvm::Value*, llvm::Value*> SplitInt64ToInt32s(
    301     llvm::IRBuilder<>* b, llvm::Value* value_64bits);
    302 
    303 // Checks whether a global variable is already created to represent a
    304 // state passed between RNG calls implemented with Philox algorithm. If not,
    305 // creates such a variable. Returns the global variable.
    306 llvm::GlobalVariable* GetOrCreateVariableForPhiloxRngState(
    307     llvm::Module* module, llvm::IRBuilder<>* b);
    308 
// Adds a value to the global state variable each time an RNG HLO is
// executed. The value of this global state variable is added to the seed
// of the Philox RNG algorithm so that calling the same RNG HLO multiple times
// should rarely produce the same result.
    313 void IncrementVariableForPhiloxRngState(int64 value, llvm::Module* module,
    314                                         llvm::IRBuilder<>* b);
    315 }  // namespace llvm_ir
    316 }  // namespace xla
    317 
    318 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_UTIL_H_
    319