/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_
     18 
#include <stddef.h>

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "llvm/ADT/Triple.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h"
#include "tensorflow/compiler/xla/service/cpu/ir_function.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

     53 namespace xla {
     54 namespace cpu {
     55 // This class is the top-level API for the XLA HLO --> LLVM IR compiler.  It
     56 // implements the DfsHloVisitor interface and emits HLO computations as LLVM IR
     57 // functions.
     58 class IrEmitter : public DfsHloVisitorWithDefault {
     59  public:
     60   // Create a new LLVM IR emitter.
     61   //
     62   // hlo_module: the HLO module we are emitting IR for.
     63   // assignment: a BufferAssignment from which we know which temporary buffers
     64   //             are used by the HLO nodes.
     65   // llvm_module: the LLVM module to emit IR into.
     66   // instruction_to_profile_idx: the mapping from HLO instructions to their
     67   //              index in the profiling array.
     68   // computation_to_profile_idx: the mapping from HLO computations to their
     69   //              index in the profiling array.
     70   // external_constant_pool: if non-null, points to an ExternalConstantPool
     71   //                         instance into which the Ir emitter can spill
     72   //                         constants.
     73   IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment,
     74             llvm::Module* llvm_module,
     75             std::unordered_map<const HloInstruction*, int64>
     76                 instruction_to_profile_idx,
     77             std::unordered_map<const HloComputation*, int64>
     78                 computation_to_profile_idx,
     79             llvm::TargetMachine* target_machine,
     80             ExternalConstantPool* external_constant_pool);
     81   ~IrEmitter() override;
     82 
     83   // Emit and return the given HLO computation as an LLVM IR
     84   // function.
     85   //
     86   // function_name_prefix is the desired name of the function. If the name is
     87   // not unique among already emitted functions then a suffix is appended to
     88   // make the name unique.
     89   //
     90   // 'is_top_level_computation' has the following meanings for each CPU backend:
     91   // *) sequential: indicates that this is the entry computation of the HLO
     92   //    module.
     93   // *) parallel: indices that this is the callee of a kCall HLO in the entry
     94   //    computation of the HLO module.
     95   //
     96   // If 'instruction_order' is not NULL, then the HLO instructions are emitted
     97   // in the given order.  In this case, 'instruction_order' must be a
     98   // topological sort of the set of nodes accessible from the root of the
     99   // computation.
    100   StatusOr<llvm::Function*> EmitComputation(
    101       HloComputation* computation, const string& function_name_prefix,
    102       bool is_top_level_computation,
    103       std::vector<const HloInstruction*>* instruction_order);
    104 
    105   llvm::IRBuilder<>* ir_builder() { return &ir_builder_; }
    106 
    107   // Emits a call to `computation` with scalar arguments `arguments`.
    108   StatusOr<llvm::Value*> EmitScalarCall(
    109       PrimitiveType return_type, HloComputation* computation,
    110       const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name);
    111 
    112  protected:
    113   //
    114   // The following methods implement the DfsHloVisitor interface.
    115   //
    116   // Default action which emits code for most operations. Operations which are
    117   // special in some way are handled explicitly in HandleFoo methods.
    118   Status DefaultAction(HloInstruction* hlo) override;
    119 
    120   Status HandleBitcast(HloInstruction* bitcast) override;
    121   Status HandleConstant(HloInstruction* constant) override;
    122   Status HandleCopy(HloInstruction* copy) override;
    123   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
    124   Status HandleSelect(HloInstruction* select) override;
    125   Status HandleDot(HloInstruction* dot) override;
    126   Status HandleConvolution(HloInstruction* convolution) override;
    127   Status HandleFft(HloInstruction* fft) override;
    128   Status HandleCrossReplicaSum(HloInstruction* crs) override;
    129   Status HandleInfeed(HloInstruction* infeed) override;
    130   Status HandleOutfeed(HloInstruction* outfeed) override;
    131   Status HandleSort(HloInstruction* sort) override;
    132   Status HandleParameter(HloInstruction* parameter) override;
    133   Status HandleReduce(HloInstruction* reduce) override;
    134   Status HandleReduceWindow(HloInstruction* reduce_window) override;
    135   Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override;
    136   Status HandleSend(HloInstruction* send) override;
    137   Status HandleSendDone(HloInstruction* send_done) override;
    138   Status HandleSlice(HloInstruction* slice) override;
    139   Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
    140   Status HandleDynamicUpdateSlice(
    141       HloInstruction* dynamic_update_slice) override;
    142   Status HandleRecv(HloInstruction* recv) override;
    143   Status HandleRecvDone(HloInstruction* recv_done) override;
    144   Status HandlePad(HloInstruction* pad) override;
    145   Status HandleTuple(HloInstruction* tuple) override;
    146   Status HandleMap(HloInstruction* map) override;
    147   Status HandleFusion(HloInstruction* fusion) override;
    148   Status HandleCall(HloInstruction* call) override;
    149   Status HandleCustomCall(HloInstruction* custom_call) override;
    150   Status HandleWhile(HloInstruction* xla_while) override;
    151   Status HandleConcatenate(HloInstruction* concatenate) override;
    152   Status HandleConditional(HloInstruction* conditional) override;
    153   Status FinishVisit(HloInstruction* root) override;
    154 
    155   Status Preprocess(HloInstruction* hlo) override;
    156   Status Postprocess(HloInstruction* hlo) override;
    157 
    158  private:
    159   // Private helper to initialize an IR function for the computation.
    160   void InitializeIrFunction(const string& function_name);
    161 
    162   template <typename T>
    163   llvm::Value* GetProfileCounterCommon(
    164       const T& hlo,
    165       const std::unordered_map<const T*, int64>& profile_index_map);
    166 
    167   // Convenience functions to generate a GEP into the profile counter parameter
    168   // which would correspond to the index for a given HLO instruction or
    169   // computation.
    170   llvm::Value* GetProfileCounterFor(const HloInstruction& instruction) {
    171     return GetProfileCounterCommon<HloInstruction>(instruction,
    172                                                    instruction_to_profile_idx_);
    173   }
    174 
    175   llvm::Value* GetProfileCounterFor(const HloComputation& computation) {
    176     return GetProfileCounterCommon<HloComputation>(computation,
    177                                                    computation_to_profile_idx_);
    178   }
    179 
    180   // Gets the IR Value emitted previously for the given hlo.
    181   //
    182   // Prefer calling GetIrArrayFor if the value you're reading is a buffer,
    183   // because GetIrArrayFor annotates buffer's loads/stores with noalias
    184   // metadata.
    185   //
    186   // Make sure to call this only when you're certain a value *was* emitted - if
    187   // not found, this will log a fatal error.
    188   llvm::Value* GetEmittedValueFor(const HloInstruction* hlo);
    189 
    190   // Gets an IrArray representing the given hlo.
    191   llvm_ir::IrArray GetIrArrayFor(const HloInstruction* hlo);
    192 
    193   // Gets a list of IrArrays, one for each of hlo's operands.
    194   std::vector<llvm_ir::IrArray> GetIrArraysForOperandsOf(
    195       const HloInstruction* hlo);
    196 
    197   // Augments IrArray with aliasing information.
    198   void AddAliasingInformationToIrArray(const HloInstruction& hlo,
    199                                        llvm_ir::IrArray* array) {
    200     alias_analysis_.AddAliasingInformationToIrArray(hlo, array);
    201   }
    202 
    203   // Convenience function to get the IR type matching the given shape.
    204   llvm::Type* IrShapeType(const Shape& shape);
    205 
    206   // Get the llvm::Value* that represents the "prof_counters" argument of the
    207   // computation function being emitted by this emitter.
    208   llvm::Value* GetProfileCountersArgument();
    209 
    210   // Get the xla::ExecutableRunOptions that represents the "run_options"
    211   // argument of the computation function being emitted by this emitter.
    212   llvm::Value* GetExecutableRunOptionsArgument();
    213 
    214   // Get the llvm::Value* that represents the "temps" argument of the
    215   // computation function being emitted by this emitter.
    216   llvm::Value* GetTempBuffersArgument();
    217 
    218   // Emits code that computes the address of the given temporary buffer to the
    219   // function. target_shape is the shape of this temporary buffer.
    220   // The returned Value's type is a pointer to element_type.
    221   llvm::Value* EmitTempBufferPointer(const BufferAllocation::Slice& slice,
    222                                      const Shape& target_shape);
    223 
    224   // Emits a function into the current module. This can be used for
    225   // computations embedded inside other computations, such as the
    226   // function that a map operation applies.
    227   StatusOr<llvm::Function*> EmitFunction(
    228       HloComputation* function,  // The function to emit.
    229       tensorflow::StringPiece
    230           function_name_suffix);  // Used for LLVM IR register names.
    231 
    232   // Methods that emit a function call.
    233   // Parameters:
    234   //   function - The LLVM function to call.
    235   //   return_shape - The return shape of the HLO computation that was used to
    236   //     make the function.  Not the same as the return type of the function
    237   //     in LLVM, since we use output parameters for the return type.
    238   //   element_count - number of elements to return (array form only).
    239   //   parameter_addresses - pointers to be passed to the function as
    240   //     parameters.
    241   //   name - used for LLVM IR register names.
    242 
    243   // Emits a function call, returning a scalar, often an element of a larger
    244   // array.  Returns a Value for the scalar element returned by the function.
    245   llvm::Value* EmitElementFunctionCall(
    246       llvm::Function* function, const Shape& return_shape,
    247       tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
    248       tensorflow::StringPiece name);
    249 
    250   // Array function call emitter.  Stores the function's result into a supplied
    251   // buffer.
    252   // Parameters:
    253   //   function - The LLVM function to call.
    254   //   parameter_addresses - pointers to be passed to the function as
    255   //     parameters.
    256   //   return_value - pointer to a buffer where the call result is stored.
    257 
    258   void EmitArrayFunctionCallInto(
    259       llvm::Function* function,
    260       tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
    261       llvm::Value* return_value_buffer, tensorflow::StringPiece name);
    262 
    263   // Array function call emitter.  Returns a Value for the function's return
    264   // value buffer address. The return value buffer is alloca'ed by this
    265   // function.
    266   llvm::Value* EmitArrayFunctionCall(
    267       llvm::Function* function, const Shape& return_shape, int64 element_count,
    268       tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
    269       tensorflow::StringPiece name);
    270 
    271   // Verifies that the element types of all of the given operand instructions
    272   // match and are of one of the given supported types.
    273   Status ElementTypesSameAndSupported(
    274       const HloInstruction& instruction,
    275       tensorflow::gtl::ArraySlice<const HloInstruction*> operands,
    276       tensorflow::gtl::ArraySlice<PrimitiveType> supported_types);
    277 
    278   // Emit IR to perform a computation for every element in the given target op.
    279   // This produces a series of nested loops (one for each dimension of the op's
    280   // shape). The body of the inner-most loop is provided by the body_emitter
    281   // function.
    282   //
    283   // desc is an optional human-readable string that's added to the loop name in
    284   // IR.  Regardless of whether desc is provided, target_op->name() is included
    285   // in the loop name.
    286   //
    287   // TODO(jingyue): target_op should be a `const HloInstruction*`.
    288   Status EmitTargetElementLoop(
    289       HloInstruction* target_op,
    290       const llvm_ir::ElementGenerator& element_generator);
    291   Status EmitTargetElementLoop(
    292       HloInstruction* target_op, tensorflow::StringPiece desc,
    293       const llvm_ir::ElementGenerator& element_generator);
    294 
    295   // Emits a memcpy from the source instruction's result value to the
    296   // destination's.  Both source and destination must have an entry in the
    297   // emitted_value_ table.
    298   Status EmitMemcpy(const HloInstruction& source,
    299                     const HloInstruction& destination);
    300 
    301   // Emits IR to compute the target address of the buffer for the given op.
    302   // After calling this function, you can get a pointer to this buffer by
    303   // calling GetIrArrayForOp or GetEmittedValueFor.
    304   Status EmitTargetAddressForOp(const HloInstruction* op);
    305 
    306   // Structurizes "array_elements" into an MD array that represents "shape".
    307   // This is a recursive function, and "dimension_index" indicates the index of
    308   // the current dimension that the function is considering (0 means the
    309   // most-minor dimension).
    310   llvm::Constant* CreateInitializerForConstantArray(
    311       const std::vector<llvm::Constant*>& array_elements, const Shape& shape,
    312       int64 dimension_index);
    313 
    314   // Tries to codegen a reduction operation using vectorized instructions.
    315   // Returns true if successful, and false on failure.  On failure, sets
    316   // "failure_reason" to a string describing why it could not vectorize the
    317   // reduction.
    318   //
    319   // TODO(sanjoy): Some of the things we do here can be abstracted out into
    320   // concepts that generalize over other vectorizable operations.  We should
    321   // consider pulling out these abstractions into a VectorizingIrEmitter or
    322   // something similar.
    323   StatusOr<bool> EmitVectorizedReduce(
    324       HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
    325       tensorflow::gtl::ArraySlice<int64> dimensions, HloComputation* function,
    326       string* failure_reason);
    327 
    328   // We'd like to keep one or two one cache-line's worth of data in registers
    329   // without generating IR with illegal (e.g. excessively large or
    330   // non-power-of-two) vector types.  We do this by introducing a layer of
    331   // abstraction: we introduce a high level vector-like concept called a
    332   // "sharded vector" that models data paralleism, and is mapped to a sequence
    333   // scalar and vector llvm::Value s.
    334   //
    335   // For example, we can represent 29 f32 elements by a sharded vector mapped to
    336   // a sequence of LLVM values of types [<16 x f32>, <8 x f32>, <4 x f32>, f32].
    337   // Note that the last element is scalar.
    338   //
    339   // There is no requirement on the ordering or the uniqueness of the elements
    340   // mapped to sharded vectors -- we allow repeated elements, and we allow
    341   // elements to appear in any order.
    342   using ShardedVector = std::vector<llvm::Value*>;
    343 
    344   // A sharded vector type is the element-wise llvm::Type's of some
    345   // ShardedVector.
    346   using ShardedVectorType = std::vector<llvm::Type*>;
    347 
    348   // Create a sharded vector type corresponding to a "element_count" long
    349   // sequence of "element_type" values.
    350   ShardedVectorType CreateShardedVectorType(PrimitiveType element_type,
    351                                             unsigned element_count);
    352 
    353   // Emit LLVM IR to store the sharded vector "value_to_store" to
    354   // "store_address".
    355   void EmitShardedVectorStore(llvm::Value* store_address,
    356                               const ShardedVector& value_to_store,
    357                               const int alignment,
    358                               const llvm_ir::IrArray& containing_array);
    359 
    360   using ReductionGenerator = std ::function<llvm::Value*(
    361       llvm::IRBuilder<>*, llvm::Value*, llvm::Value*)>;
    362 
    363   // Tries to match the reduction function "function" to a known reduction
    364   // pattern.  Returns a non-null ReductionGenerator on a successful match,
    365   // which can be used to generate the LLVM IR corresponding to said reduction.
    366   // On failure, this stores a reason string into "failure_reason".
    367   ReductionGenerator MatchReductionGenerator(HloComputation* function,
    368                                              string* failure_reason) const;
    369 
    370   // Emits the inner loop nest that runs the reduction.  Helper function for
    371   // EmitVectorizedReduce.
    372   StatusOr<ShardedVector> EmitInnerLoopForVectorizedReduction(
    373       const ReductionGenerator& reduction_generator,
    374       const llvm_ir::IrArray::Index& output_index,
    375       const ShardedVectorType& accumulator_type, HloInstruction* init_value,
    376       HloInstruction* arg, tensorflow::gtl::ArraySlice<int64> dimensions,
    377       unsigned element_alignment);
    378 
    379   // Tries to emit a fast concatenate operation using memcpy.  Returns true if
    380   // successful, and false on failure.  On failure, sets "failure_reason" to a
    381   // string describing why it could not emit a fast concatenate.
    382   StatusOr<bool> EmitFastConcatenate(
    383       HloInstruction* concatenate,
    384       tensorflow::gtl::ArraySlice<HloInstruction*> operands,
    385       string* failure_reason);
    386 
    387   // Emits LLVM IR to transfer "element_count" elements of type "primitive_type"
    388   // from the address "source" to the address "target".
    389   void EmitTransferElements(llvm::Value* target, llvm::Value* source,
    390                             int64 element_count, PrimitiveType primitive_type,
    391                             const llvm_ir::IrArray& target_array,
    392                             const llvm_ir::IrArray& source_array);
    393 
    394   // Assignment of the temporary buffers needed by the computation and their
    395   // shape information.
    396   const BufferAssignment& assignment_;
    397 
    398   // The LLVM module into which IR will be emitted.
    399   llvm::Module* module_;
    400 
    401   // The target architecture.
    402   llvm::Triple::ArchType arch_type_;
    403 
    404   // Used to produce unique names for generated functions.
    405   NameUniquer name_uniquer_;
    406 
    407   // Map containing all previously emitted computations.
    408   std::map<HloComputation*, llvm::Function*> emitted_functions_;
    409 
    410   // Map containing all previously emitted thread-local temporary buffers.
    411   std::map<std::pair<llvm::Function*, BufferAllocation::Slice>,
    412            llvm::AllocaInst*>
    413       thread_local_buffers_;
    414 
    415   // The following fields track the IR emission state. According to LLVM memory
    416   // management rules, their memory is owned by the module (Note that IrFunction
    417   // creates the encapsulated llvm::Function s.t. it is added to the llvm
    418   // module's function list).
    419   std::unique_ptr<IrFunction> compute_function_;
    420   llvm::IRBuilder<> ir_builder_;
    421 
    422   // Maps HLO instructions to their index into the profile counter array.
    423   const std::unordered_map<const HloInstruction*, int64>
    424       instruction_to_profile_idx_;
    425 
    426   // Maps HLO computations to their index into the profile counter array.
    427   const std::unordered_map<const HloComputation*, int64>
    428       computation_to_profile_idx_;
    429 
    430   // Maps HLOs to Values emitted for them.
    431   std::unordered_map<const HloInstruction*, llvm::Value*> emitted_value_;
    432 
    433   llvm_ir::AliasAnalysis alias_analysis_;
    434 
    435   // The number of root instruction outer dimensions used in parallel loop
    436   // emission (ParallelLoopEmitter).
    437   int64 num_dynamic_loop_bounds_ = 0;
    438 
    439   // Returns whether the given instruction should be emitted as a parallel loop.
    440   bool ShouldEmitParallelLoopFor(const HloInstruction& op) const {
    441     // Emit parallel loop for root instruction if dynamic outer-dimension loop
    442     // bounds were specified.
    443     return num_dynamic_loop_bounds_ > 0 &&
    444            op.parent()->root_instruction() == &op;
    445   }
    446 
    447   // This struct contains all the state needed to emit instructions for
    448   // profiling a computation.
    449   class ProfilingState {
    450    public:
    451     ProfilingState() : use_rdtscp_(false), prof_counters_(nullptr) {}
    452     ProfilingState(bool use_rdtscp, llvm::Value* prof_counters)
    453         : use_rdtscp_(use_rdtscp), prof_counters_(prof_counters) {}
    454 
    455     // Record the cycle counter before an HLO executes.
    456     void RecordCycleStart(llvm::IRBuilder<>* ir_builder, HloInstruction* hlo);
    457     // Record the number of cycles it took for an HLO to execute.
    458     void RecordCycleDelta(llvm::IRBuilder<>* ir_builder, HloInstruction* hlo,
    459                           llvm::Value* prof_counter);
    460     // Record the number of cycles it took for the entire computation to
    461     // execute.
    462     void RecordCompleteComputation(llvm::IRBuilder<>* ir_builder,
    463                                    llvm::Value* prof_counter);
    464 
    465     // Convenience function to generate a call to an intrinsic which reads the
    466     // CPU cycle counter.
    467     llvm::Value* ReadCycleCounter(llvm::IRBuilder<>* ir_builder);
    468 
    469     // Store the cycle counter delta to the per-HLO profile counter.
    470     void UpdateProfileCounter(llvm::IRBuilder<>* ir_builder,
    471                               llvm::Value* prof_counter, llvm::Value* cycle_end,
    472                               llvm::Value* cycle_start);
    473 
    474    private:
    475     // Should we use the x86-specific rdtscp or the generic readcyclecounter
    476     // intrinsic?
    477     bool use_rdtscp_;
    478 
    479     // The argument which corresponds to the profile counter buffer.
    480     llvm::Value* prof_counters_;
    481 
    482     // The first read cycle counter in the program.
    483     llvm::Value* first_read_cycle_start_ = nullptr;
    484 
    485     // The last read cycle counter in the program.
    486     llvm::Value* last_read_cycle_end_ = nullptr;
    487 
    488     // An alloca used to hold the output of the aux value returned by the rdtscp
    489     // intrinsic.
    490     llvm::Value* aux_i8ptr_ = nullptr;
    491 
    492     // Maps HLOs to the value the cycle counter contained right before the HLO
    493     // began to execute.
    494     std::unordered_map<const HloInstruction*, llvm::Value*> cycle_starts_;
    495   };
    496 
    497   ProfilingState profiling_state_;
    498 
    499   // Given a load instruction and a shape or buffer size, annotate the load's
    500   // result with the alignment required by the shape or size.
    501   void AttachAlignmentMetadataForLoad(llvm::LoadInst* load, const Shape& shape);
    502   void AttachAlignmentMetadataForLoad(llvm::LoadInst* load, int64 buffer_size);
    503 
    504   // Given a load instruction and a shape or buffer size, annotate the load's
    505   // result with the dereferenceable bytes required by the shape / buffer size.
    506   void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
    507                                             const Shape& shape);
    508   void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
    509                                             int64 buffer_size);
    510 
    511   // Calculate the alignment of a buffer allocated for a given shape.
    512   int MinimumAlignmentForShape(const Shape& shape);
    513 
    514   // Calculate the alignment of a buffer allocated for a given primitive type.
    515   int MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type);
    516 
    517   // Calculate the alignment of a buffer with a particular size.
    518   int MinimumAlignmentForBufferSize(int64 buffer_size);
    519 
    520   // Returns the number of bytes within the shape.
    521   int64 ByteSizeOf(const Shape& shape) const;
    522 
    523   enum class XfeedKind {
    524     kInfeed,
    525     kOutfeed,
    526   };
    527 
    528   // Emit IR to transfer between a {infeed,outfeed} buffer and an in-program
    529   // address.
    530   Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
    531                            llvm::Value* program_buffer_address);
    532 
    533   const HloModuleConfig& hlo_module_config_;
    534 
    535   const bool parallel_cpu_backend_;
    536 
    537   bool is_top_level_computation_;
    538 
    539   TargetMachineFeatures target_machine_features_;
    540 
    541   int64 external_global_constant_counter_ = 0;
    542   ExternalConstantPool* external_constant_pool_;
    543 
    544   TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter);
    545 };
    546 
    547 }  // namespace cpu
    548 }  // namespace xla
    549 
    550 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_
    551