Home | History | Annotate | Download | only in cpu
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_
     18 
#include <stddef.h>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
     25 
     26 #include "absl/container/flat_hash_map.h"
     27 #include "absl/strings/string_view.h"
     28 #include "absl/types/span.h"
     29 #include "llvm/ADT/Triple.h"
     30 #include "llvm/IR/Function.h"
     31 #include "llvm/IR/IRBuilder.h"
     32 #include "llvm/IR/Module.h"
     33 #include "llvm/IR/Value.h"
     34 #include "llvm/Target/TargetMachine.h"
     35 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
     36 #include "tensorflow/compiler/xla/service/cpu/ir_function.h"
     37 #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
     38 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
     39 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     40 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     41 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
     42 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
     43 #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
     44 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
     45 #include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
     46 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
     47 #include "tensorflow/compiler/xla/service/name_uniquer.h"
     48 #include "tensorflow/compiler/xla/statusor.h"
     49 #include "tensorflow/compiler/xla/types.h"
     50 #include "tensorflow/compiler/xla/xla_data.pb.h"
     51 #include "tensorflow/core/platform/macros.h"
     52 #include "tensorflow/core/platform/types.h"
     53 
     54 namespace xla {
     55 namespace cpu {
     56 // This class is the top-level API for the XLA HLO --> LLVM IR compiler.  It
     57 // implements the DfsHloVisitor interface and emits HLO computations as LLVM IR
     58 // functions.
     59 class IrEmitter : public DfsHloVisitorWithDefault,
     60                   public IrBuilderMixin<IrEmitter> {
     61  public:
     62   using GeneratorForOperandIrArrays =
     63       std::function<std::vector<llvm_ir::IrArray>()>;
     64 
     65   // Create a new LLVM IR emitter.
     66   //
     67   // hlo_module: the HLO module we are emitting IR for.
     68   // assignment: a BufferAssignment from which we know which buffers are used by
     69   //             the HLO nodes.
     70   // llvm_module: the LLVM module to emit IR into.
     71   // instruction_to_profile_idx: the mapping from HLO instructions to their
     72   //              index in the profiling array.
     73   // computation_to_profile_idx: the mapping from HLO computations to their
     74   //              index in the profiling array.
     75   // emit_code_for_msan: whether emitted code should be compatible with msan.
     76   IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment,
     77             llvm::Module* llvm_module,
     78             std::unordered_map<const HloInstruction*, int64>
     79                 instruction_to_profile_idx,
     80             std::unordered_map<const HloComputation*, int64>
     81                 computation_to_profile_idx,
     82             const TargetMachineFeatures* target_machine,
     83             bool emit_code_for_msan);
     84   ~IrEmitter() override;
     85 
     86   // Emit and return the given HLO computation as an LLVM IR
     87   // function.
     88   //
     89   // function_name_prefix is the desired name of the function. If the name is
     90   // not unique among already emitted functions then a suffix is appended to
     91   // make the name unique.
     92   //
     93   // 'is_top_level_computation' has the following meanings for each CPU backend:
     94   // *) sequential: indicates that this is the entry computation of the HLO
     95   //    module.
     96   // *) parallel: indicates that this is the callee of a kCall HLO in the
     97   //    entry computation of the HLO module.
     98   //
     99   // If 'instruction_order' is not NULL, then the HLO instructions are emitted
    100   // in the given order.  In this case, 'instruction_order' must be a
    101   // topological sort of the set of nodes accessible from the root of the
    102   // computation.
    103   StatusOr<llvm::Function*> EmitComputation(
    104       HloComputation* computation, const string& function_name_prefix,
    105       bool is_top_level_computation,
    106       absl::Span<HloInstruction* const> instruction_order);
    107 
    108   llvm::IRBuilder<>* b() { return &b_; }
    109 
    110   // builder() is for IrBuilderMixin.
    111   llvm::IRBuilder<>* builder() { return &b_; }
    112 
    113   // Emit an LLVM global variable for every constant buffer allocation.
    114   Status EmitConstantGlobals();
    115 
    116   // Emit code to map one element according to `map_instr`.
    117   llvm::Value* EmitElementalMap(
    118       const HloMapInstruction& map_instr,
    119       absl::Span<llvm::Value* const> elemental_operands,
    120       absl::string_view name);
    121   // Emit code to emit the element at `index` for a reduce window instruction.
    122   StatusOr<llvm::Value*> EmitElementalReduceWindow(
    123       const HloReduceWindowInstruction* reduce_window,
    124       const llvm_ir::ElementGenerator& input_generator,
    125       const llvm_ir::IrArray::Index& index);
    126   // Emit code to emit the element at `index` for a convolution instruction.
    127   StatusOr<llvm::Value*> EmitElementalConvolution(
    128       const HloConvolutionInstruction* convolution,
    129       const llvm_ir::ElementGenerator& input_generator,
    130       const llvm_ir::ElementGenerator& kernel_generator,
    131       const llvm_ir::IrArray::Index& index);
    132   // Emit code to emit the element at `index` for a reduce instruction.
    133   StatusOr<llvm::Value*> EmitElementalReduce(
    134       const HloReduceInstruction* reduce,
    135       const llvm_ir::ElementGenerator& input_generator,
    136       const llvm_ir::ElementGenerator& initial_value_generator,
    137       const llvm_ir::IrArray::Index& index);
    138 
    139  protected:
    140   //
    141   // The following methods implement the DfsHloVisitor interface.
    142   //
    143   // Default action which emits code for most operations. Operations which are
    144   // special in some way are handled explicitly in HandleFoo methods.
    145   Status DefaultAction(HloInstruction* hlo) override;
    146 
    147   Status HandleAllToAll(HloInstruction* instruction) override;
    148   Status HandleBitcast(HloInstruction* bitcast) override;
    149   Status HandleConstant(HloInstruction* constant) override;
    150   Status HandleCopy(HloInstruction* copy) override;
    151   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
    152   Status HandleSelect(HloInstruction* select) override;
    153   Status HandleTupleSelect(HloInstruction* tuple_select) override;
    154   Status HandleDot(HloInstruction* dot) override;
    155   Status HandleConvolution(HloInstruction* convolution) override;
    156   Status HandleFft(HloInstruction* fft) override;
    157   Status HandleAllReduce(HloInstruction* crs) override;
    158   Status HandleInfeed(HloInstruction* infeed) override;
    159   Status HandleOutfeed(HloInstruction* outfeed) override;
    160   Status HandleSort(HloInstruction* sort) override;
    161   Status HandleParameter(HloInstruction* parameter) override;
    162   Status HandleReduce(HloInstruction* reduce) override;
    163   Status HandleReduceWindow(HloInstruction* reduce_window) override;
    164   Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override;
    165   Status HandleSend(HloInstruction* send) override;
    166   Status HandleSendDone(HloInstruction* send_done) override;
    167   Status HandleSlice(HloInstruction* slice) override;
    168   Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
    169   Status HandleDynamicUpdateSlice(
    170       HloInstruction* dynamic_update_slice) override;
    171   Status HandleRecv(HloInstruction* recv) override;
    172   Status HandleRecvDone(HloInstruction* recv_done) override;
    173   Status HandlePad(HloInstruction* pad) override;
    174   Status HandleTuple(HloInstruction* tuple) override;
    175   Status HandleFusion(HloInstruction* fusion) override;
    176   Status HandleCall(HloInstruction* call) override;
    177   Status HandleCustomCall(HloInstruction* custom_call) override;
    178   Status HandleWhile(HloInstruction* xla_while) override;
    179   Status HandleConcatenate(HloInstruction* concatenate) override;
    180   Status HandleConditional(HloInstruction* conditional) override;
    181   Status HandleScatter(HloInstruction* scatter) override;
    182   Status HandleAfterAll(HloInstruction* after_all) override;
    183   Status HandleAddDependency(HloInstruction* add_dependency) override;
    184   Status HandleRng(HloInstruction* rng) override;
    185   Status FinishVisit(HloInstruction* root) override;
    186 
    187   Status Preprocess(HloInstruction* hlo) override;
    188   Status Postprocess(HloInstruction* hlo) override;
    189 
    190   // A convenient helper for calling BufferAssignment::GetUniqueSlice.
    191   BufferAllocation::Slice GetAllocationSlice(
    192       const HloInstruction& hlo, const ShapeIndex& index = {}) const {
    193     return assignment_.GetUniqueSlice(&hlo, index).ConsumeValueOrDie();
    194   }
    195 
    196  private:
    197   // Private helper to initialize an IR function for the computation.
    198   void InitializeIrFunction(const string& function_name);
    199 
    200   template <typename T>
    201   llvm::Value* GetProfileCounterCommon(
    202       const T& hlo,
    203       const std::unordered_map<const T*, int64>& profile_index_map);
    204 
    205   // Convenience functions to generate a GEP into the profile counter parameter
    206   // which would correspond to the index for a given HLO instruction or
    207   // computation.
    208   llvm::Value* GetProfileCounterFor(const HloInstruction& instruction) {
    209     return GetProfileCounterCommon<HloInstruction>(instruction,
    210                                                    instruction_to_profile_idx_);
    211   }
    212 
    213   llvm::Value* GetProfileCounterFor(const HloComputation& computation) {
    214     return GetProfileCounterCommon<HloComputation>(computation,
    215                                                    computation_to_profile_idx_);
    216   }
    217 
    218   // Gets the IR Value emitted previously for the given hlo.
    219   //
    220   // Prefer calling GetIrArrayFor if the value you're reading is a buffer,
    221   // because GetIrArrayFor annotates buffer's loads/stores with noalias
    222   // metadata.
    223   //
    224   // Make sure to call this only when you're certain a value *was* emitted - if
    225   // not found, this will log a fatal error.
    226   llvm::Value* GetEmittedValueFor(const HloInstruction* hlo);
    227 
    228   // Gets an IrArray representing the given hlo.
    229   llvm_ir::IrArray GetIrArrayFor(const HloInstruction* hlo);
    230 
    231   // Gets a list of IrArrays, one for each of hlo's operands.
    232   std::vector<llvm_ir::IrArray> GetIrArraysForOperandsOf(
    233       const HloInstruction* hlo);
    234 
    235   GeneratorForOperandIrArrays GetGeneratorForOperandIrArrays(
    236       HloInstruction* unnested_hlo) {
    237     return [=]() { return GetIrArraysForOperandsOf(unnested_hlo); };
    238   }
    239 
    240   // Augments IrArray with aliasing information.
    241   void AddAliasingInformationToIrArray(const HloInstruction& hlo,
    242                                        llvm_ir::IrArray* array) {
    243     alias_analysis_.AddAliasingInformationToIrArray(hlo, array);
    244   }
    245 
    246   // Convenience function to get the IR type matching the given shape.
    247   llvm::Type* IrShapeType(const Shape& shape);
    248 
    249   // Get the llvm::Value* that represents the "prof_counters" argument of the
    250   // computation function being emitted by this emitter.
    251   llvm::Value* GetProfileCountersArgument();
    252 
    253   // Get the xla::ExecutableRunOptions that represents the "run_options"
    254   // argument of the computation function being emitted by this emitter.
    255   llvm::Value* GetExecutableRunOptionsArgument();
    256 
    257   // Get the llvm::Value* that represents the "buffer_table" argument of the
    258   // computation function being emitted by this emitter.
    259   llvm::Value* GetBufferTableArgument();
    260 
    261   // Helper for EmitBufferPointer.
    262   llvm::Value* EmitGlobalBufferPointer(const BufferAllocation::Slice& slice,
    263                                        const Shape& target_shape);
    264 
    265   // Helper for EmitBufferPointer.
    266   llvm::Value* EmitThreadLocalBufferPointer(
    267       const BufferAllocation::Slice& slice, const Shape& target_shape);
    268 
    269   // Emits code that computes the address of the given buffer allocation slice.
    270   llvm::Value* EmitBufferPointer(const BufferAllocation::Slice& slice,
    271                                  const Shape& target_shape);
    272 
    273   // Emits a call to a thread local function (e.g. to the computation nested
    274   // within a reduce or a map).  Thread local callees (by definition) only write
    275   // to and read from thread local allocations.
    276   //
    277   // `parameters` holds the *scalar values* that need to be passed to the
    278   // callee.  The return value is the scalar returned by the callee.
    279   llvm::Value* EmitThreadLocalCall(const HloComputation& callee,
    280                                    absl::Span<llvm::Value* const> parameters,
    281                                    absl::string_view name);
    282 
    283   // Emits a call to a "global" function (e.g. to the computation nested within
    284   // a kWhile or a kCall).  Buffer assignment unambiguously assigns buffers to
    285   // the parameters and return values for these computations so there is no need
    286   // to explicitly pass parameters or return results.
    287   void EmitGlobalCall(const HloComputation& callee, absl::string_view name);
    288 
    289   // Returns the buffer to which a global call to `callee` would have written
    290   // its result.
    291   llvm::Value* GetBufferForGlobalCallReturnValue(const HloComputation& callee);
    292 
    293   // Verifies that the element types of all of the given operand instructions
    294   // match and are of one of the given supported types.
    295   Status ElementTypesSameAndSupported(
    296       const HloInstruction& instruction,
    297       absl::Span<const HloInstruction* const> operands,
    298       absl::Span<const PrimitiveType> supported_types);
    299 
    300   // Emit IR to perform a computation for every element in the given target op.
    301   // This produces a series of nested loops (one for each dimension of the op's
    302   // shape). The body of the inner-most loop is provided by the body_emitter
    303   // function.
    304   //
    305   // desc is an optional human-readable string that's added to the loop name in
    306   // IR.  Regardless of whether desc is provided, target_op->name() is included
    307   // in the loop name.
    308   //
    309   // TODO(jingyue): target_op should be a `const HloInstruction*`.
    310   Status EmitTargetElementLoop(
    311       HloInstruction* target_op,
    312       const llvm_ir::ElementGenerator& element_generator);
    313   Status EmitTargetElementLoop(
    314       HloInstruction* target_op, absl::string_view desc,
    315       const llvm_ir::ElementGenerator& element_generator);
    316 
    317   // Emits a memcpy from the source instruction's result value to the
    318   // destination's.  Both source and destination must have an entry in the
    319   // emitted_value_ table.
    320   Status EmitMemcpy(const HloInstruction& source,
    321                     const HloInstruction& destination);
    322 
    323   // Emits IR to compute the target address of the buffer for the given op.
    324   // After calling this function, you can get a pointer to this buffer by
    325   // calling GetIrArrayFor or GetEmittedValueFor.
    326   Status EmitTargetAddressForOp(const HloInstruction* op);
    327 
    328   // Structurizes "array_elements" into an MD array that represents "shape".
    329   // This is a recursive function, and "dimension_index" indicates the index of
    330   // the current dimension that the function is considering (0 means the
    331   // most-minor dimension).
    332   llvm::Constant* CreateInitializerForConstantArray(
    333       const std::vector<llvm::Constant*>& array_elements, const Shape& shape,
    334       int64 dimension_index);
    335 
    336   // Tries to codegen a reduction operation using vectorized instructions.
    337   // Returns true if successful, and false on failure.  On failure, sets
    338   // "failure_reason" to a string describing why it could not vectorize the
    339   // reduction.
    340   //
    341   // TODO(sanjoy): Some of the things we do here can be abstracted out into
    342   // concepts that generalize over other vectorizable operations.  We should
    343   // consider pulling out these abstractions into a VectorizingIrEmitter or
    344   // something similar.
    345   StatusOr<bool> EmitVectorizedReduce(HloInstruction* reduce,
    346                                       HloInstruction* arg,
    347                                       HloInstruction* init_value,
    348                                       absl::Span<const int64> dimensions,
    349                                       HloComputation* function,
    350                                       string* failure_reason);
    351 
    352   // We'd like to keep one or two cache lines' worth of data in registers
    353   // without generating IR with illegal (e.g. excessively large or
    354   // non-power-of-two) vector types.  We do this by introducing a layer of
    355   // abstraction: we introduce a high level vector-like concept called a
    356   // "sharded vector" that models data parallelism, and is mapped to a
    357   // sequence of scalar and vector llvm::Values.
    358   //
    359   // For example, we can represent 29 f32 elements by a sharded vector mapped to
    360   // a sequence of LLVM values of types [<16 x f32>, <8 x f32>, <4 x f32>, f32].
    361   // Note that the last element is scalar.
    362   //
    363   // There is no requirement on the ordering or the uniqueness of the elements
    364   // mapped to sharded vectors -- we allow repeated elements, and we allow
    365   // elements to appear in any order.
    366   using ShardedVector = std::vector<llvm::Value*>;
    367 
    368   // A sharded vector type is the element-wise llvm::Type's of some
    369   // ShardedVector.
    370   using ShardedVectorType = std::vector<llvm::Type*>;
    371 
    372   // Create a sharded vector type corresponding to an "element_count" long
    373   // sequence of "element_type" values.
    374   ShardedVectorType CreateShardedVectorType(PrimitiveType element_type,
    375                                             unsigned element_count);
    376 
    377   // Emit LLVM IR to store the sharded vector "value_to_store" to
    378   // "store_address".
    379   void EmitShardedVectorStore(llvm::Value* store_address,
    380                               const ShardedVector& value_to_store,
    381                               const int alignment,
    382                               const llvm_ir::IrArray& containing_array);
    383 
    384   using ReductionGenerator = std ::function<llvm::Value*(
    385       llvm::IRBuilder<>*, llvm::Value*, llvm::Value*)>;
    386 
    387   // Tries to match the reduction function "function" to a known reduction
    388   // pattern.  Returns a non-null ReductionGenerator on a successful match,
    389   // which can be used to generate the LLVM IR corresponding to said reduction.
    390   // On failure, this stores a reason string into "failure_reason".
    391   ReductionGenerator MatchReductionGenerator(HloComputation* function,
    392                                              string* failure_reason) const;
    393 
    394   // Emits the inner loop nest that runs the reduction.  Helper function for
    395   // EmitVectorizedReduce.
    396   StatusOr<ShardedVector> EmitInnerLoopForVectorizedReduction(
    397       const ReductionGenerator& reduction_generator,
    398       const llvm_ir::IrArray::Index& output_index,
    399       const ShardedVectorType& accumulator_type, HloInstruction* init_value,
    400       HloInstruction* arg, absl::Span<const int64> dimensions,
    401       unsigned element_alignment);
    402 
    403   // Tries to emit a fast concatenate operation using memcpy.  Returns true if
    404   // successful, and false on failure.  On failure, sets "failure_reason" to a
    405   // string describing why it could not emit a fast concatenate.
    406   StatusOr<bool> EmitFastConcatenate(HloInstruction* concatenate,
    407                                      absl::Span<HloInstruction* const> operands,
    408                                      string* failure_reason);
    409 
    410   // Emits LLVM IR to transfer "element_count" elements of type "primitive_type"
    411   // from the address "source" to the address "target".
    412   void EmitTransferElements(llvm::Value* target, llvm::Value* source,
    413                             int64 element_count, PrimitiveType primitive_type,
    414                             const llvm_ir::IrArray& target_array,
    415                             const llvm_ir::IrArray& source_array);
    416 
    417   // Assignment of the buffers needed by the computation and their shape
    418   // information.
    419   const BufferAssignment& assignment_;
    420 
    421   // The LLVM module into which IR will be emitted.
    422   llvm::Module* module_;
    423 
    424   // The target architecture.
    425   llvm::Triple::ArchType arch_type_;
    426 
    427   // Used to produce unique names for generated functions.
    428   NameUniquer name_uniquer_;
    429 
    430   // Map containing all previously emitted computations.
    431   std::map<const HloComputation*, llvm::Function*> emitted_functions_;
    432 
    433   // Map containing all previously emitted thread-local temporary buffers.
    434   std::map<std::pair<llvm::Function*, BufferAllocation::Slice>, llvm::Value*>
    435       thread_local_buffers_;
    436 
    437   // The following fields track the IR emission state. According to LLVM memory
    438   // management rules, their memory is owned by the module (Note that IrFunction
    439   // creates the encapsulated llvm::Function s.t. it is added to the llvm
    440   // module's function list).
    441   std::unique_ptr<IrFunction> compute_function_;
    442   llvm::IRBuilder<> b_;
    443 
    444   // The buffer allocation slice for the root of the computation being compiled.
    445   // Only relevant for thread local computations.
    446   BufferAllocation::Slice computation_root_allocation_;
    447 
    448   // Maps the buffer allocation slices for the parameters to the computation
    449   // being compiled to their parameter numbers.  Only relevant for thread local
    450   // computations.
    451   absl::flat_hash_map<BufferAllocation::Index, int64>
    452       computation_parameter_allocations_;
    453 
    454   // Maps HLO instructions to their index into the profile counter array.
    455   const std::unordered_map<const HloInstruction*, int64>
    456       instruction_to_profile_idx_;
    457 
    458   // Maps HLO computations to their index into the profile counter array.
    459   const std::unordered_map<const HloComputation*, int64>
    460       computation_to_profile_idx_;
    461 
    462   // Maps HLOs to Values emitted for them.
    463   absl::flat_hash_map<const HloInstruction*, llvm::Value*> emitted_value_;
    464 
    465   llvm_ir::AliasAnalysis alias_analysis_;
    466 
    467   // The number of root instruction outer dimensions used in parallel loop
    468   // emission (ParallelLoopEmitter).
    469   int64 num_dynamic_loop_bounds_ = 0;
    470 
    471   // Returns whether the given instruction should be emitted as a parallel loop.
    472   bool ShouldEmitParallelLoopFor(const HloInstruction& op) const {
    473     // Emit parallel loop for root instruction if dynamic outer-dimension loop
    474     // bounds were specified.
    475     return num_dynamic_loop_bounds_ > 0 &&
    476            op.parent()->root_instruction() == &op;
    477   }
    478 
    479   // This struct contains all the state needed to emit instructions for
    480   // profiling a computation.
    481   class ProfilingState {
    482    public:
    483     ProfilingState() : use_rdtscp_(false) {}
    484     explicit ProfilingState(bool use_rdtscp) : use_rdtscp_(use_rdtscp) {}
    485 
    486     // Record the cycle counter before an HLO executes.
    487     void RecordCycleStart(llvm::IRBuilder<>* b, HloInstruction* hlo);
    488     // Record the number of cycles it took for an HLO to execute.
    489     void RecordCycleDelta(llvm::IRBuilder<>* b, HloInstruction* hlo,
    490                           llvm::Value* prof_counter);
    491     // Record the number of cycles it took for the entire computation to
    492     // execute.
    493     void RecordCompleteComputation(llvm::IRBuilder<>* b,
    494                                    llvm::Value* prof_counter);
    495 
    496     // Convenience function to generate a call to an intrinsic which reads the
    497     // CPU cycle counter.
    498     llvm::Value* ReadCycleCounter(llvm::IRBuilder<>* b);
    499 
    500     // Store the cycle counter delta to the per-HLO profile counter.
    501     void UpdateProfileCounter(llvm::IRBuilder<>* b, llvm::Value* prof_counter,
    502                               llvm::Value* cycle_end, llvm::Value* cycle_start);
    503 
    504    private:
    505     // Should we use the x86-specific rdtscp or the generic readcyclecounter
    506     // intrinsic?
    507     bool use_rdtscp_;
    508 
    509     // The first read cycle counter in the program.
    510     llvm::Value* first_read_cycle_start_ = nullptr;
    511 
    512     // The last read cycle counter in the program.
    513     llvm::Value* last_read_cycle_end_ = nullptr;
    514 
    515     // An alloca used to hold the output of the aux value returned by the rdtscp
    516     // intrinsic.
    517     llvm::Value* aux_i8ptr_ = nullptr;
    518 
    519     // Maps HLOs to the value the cycle counter contained right before the HLO
    520     // began to execute.
    521     std::unordered_map<const HloInstruction*, llvm::Value*> cycle_starts_;
    522   };
    523 
    524   ProfilingState profiling_state_;
    525 
    526   // Given a load instruction and a shape or buffer size, annotate the load's
    527   // result with the alignment required by the shape or size.
    528   void AttachAlignmentMetadataForLoad(llvm::LoadInst* load, const Shape& shape);
    529   void AttachAlignmentMetadataForLoad(llvm::LoadInst* load, int64 buffer_size);
    530 
    531   // Given a load instruction and a shape or buffer size, annotate the load's
    532   // result with the dereferenceable bytes required by the shape / buffer size.
    533   void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
    534                                             const Shape& shape);
    535   void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
    536                                             int64 buffer_size);
    537 
    538   // Calculate the alignment of a buffer allocated for a given shape.
    539   int MinimumAlignmentForShape(const Shape& shape);
    540 
    541   // Calculate the alignment of a buffer allocated for a given primitive type.
    542   int MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type);
    543 
    544   // Returns the number of bytes within the shape.
    545   int64 ByteSizeOf(const Shape& shape) const;
    546 
    547   enum class XfeedKind {
    548     kInfeed,
    549     kOutfeed,
    550   };
    551 
    552   // Emit IR to transfer between a {infeed,outfeed} buffer and an in-program
    553   // address.
    554   Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
    555                            llvm::Value* program_buffer_address);
    556 
    557   // Returns a ConstExpr bitcast.
    558   llvm::Constant* EmitGlobalForLiteral(const Literal& literal);
    559 
    560   const HloModuleConfig& hlo_module_config_;
    561 
    562   bool is_top_level_computation_;
    563 
    564   const TargetMachineFeatures& target_machine_features_;
    565 
    566   struct LiteralPtrHashFunctor {
    567     size_t operator()(const Literal* literal) const { return literal->Hash(); }
    568   };
    569 
    570   struct LiteralPtrEqualityFunctor {
    571     bool operator()(const Literal* lhs, const Literal* rhs) const {
    572       return *lhs == *rhs;
    573     }
    574   };
    575 
    576   absl::flat_hash_map<const Literal*, llvm::Constant*, LiteralPtrHashFunctor,
    577                       LiteralPtrEqualityFunctor>
    578       emitted_literals_;
    579 
    580   absl::flat_hash_map<BufferAllocation::Index, llvm::Constant*>
    581       constant_buffer_to_global_;
    582 
    583   std::vector<const HloComputation*> thread_local_computations_;
    584   std::vector<const HloComputation*> global_computations_;
    585 
    586   bool emit_code_for_msan_;
    587 
    588   TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter);
    589 };
    590 
    591 }  // namespace cpu
    592 }  // namespace xla
    593 
    594 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_
    595