/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_

#include <functional>
#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace gpu {

// Abstract base class for translating HLO graphs to LLVM IR for a GPU.
//
// There are two concrete subclasses of IrEmitter: IrEmitterNested and
// IrEmitterUnnested.  In the unnested variety, each HLO gets its own kernel
// function, whereas in the nested version the whole computation is emitted as
// one *non-kernel* function.
//
// In XLA, kernel functions never call other kernel functions.  This means that
// if we have a kernel -- e.g. implementing a kReduce HLO -- that wants to use
// an HLO computation as a "subroutine" -- e.g. the HLO computation that
// specifies how to reduce two elements -- then the subroutine computation must
// be emitted using IrEmitterNested.
//
// Fusion nodes are a special case.  A fusion node is emitted using
// IrEmitterUnnested, but the code is generated using FusedIrEmitter, which is
// not a subclass of gpu::IrEmitter, and in fact is better understood as an IR
// generator generator.  See comments on that class.
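//
// For example (hypothetical HLO, shown only to illustrate the split):
//
//   reduce = f32[] reduce(f32[1024] input, f32[] init), to_apply=add
//
// is emitted as a kernel by IrEmitterUnnested, while the `add` computation it
// applies is emitted by IrEmitterNested as an ordinary device function that
// the kernel calls.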
class IrEmitter : public DfsHloVisitorWithDefault {
 public:
  IrEmitter(const IrEmitter&) = delete;
  IrEmitter& operator=(const IrEmitter&) = delete;

  Status DefaultAction(HloInstruction* hlo) override;
  Status HandleConstant(HloInstruction* constant) override;
  Status HandleBitcast(HloInstruction* bitcast) override;
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleDot(HloInstruction* dot) override;
  Status HandleConvolution(HloInstruction* convolution) override;
  Status HandleFft(HloInstruction* fft) override;
  Status HandleCrossReplicaSum(HloInstruction* crs) override;
  Status HandleInfeed(HloInstruction* infeed) override;
  Status HandleOutfeed(HloInstruction* outfeed) override;
  Status HandleSort(HloInstruction* sort) override;
  Status HandleSend(HloInstruction* send) override;
  Status HandleSendDone(HloInstruction* send_done) override;
  Status HandleRecv(HloInstruction* recv) override;
  Status HandleRecvDone(HloInstruction* recv_done) override;
  Status HandleParameter(HloInstruction* parameter) override;
  Status HandleReduce(HloInstruction* reduce) override;
  Status HandleTuple(HloInstruction* tuple) override;
  Status HandleSelect(HloInstruction* select) override;
  Status HandleFusion(HloInstruction* fusion) override;
  Status HandleCall(HloInstruction* call) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;
  Status HandleRng(HloInstruction* random) override;
  Status HandleBatchNormInference(HloInstruction* batch_norm) override;
  Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
  Status HandleBatchNormGrad(HloInstruction* batch_norm) override;

  Status FinishVisit(HloInstruction* root) override { return Status::OK(); }

 protected:
  // Constructs an IrEmitter with the given IrEmitter context.
  // ir_emitter_context is owned by the caller and should outlive the IrEmitter
  // object.
  explicit IrEmitter(const HloModuleConfig& hlo_module_config,
                     IrEmitterContext* ir_emitter_context, bool is_nested);

  // Helper for calling HloToIrBindings::GetIrArray.
  //
  // Gets the IrArray which contains inst.  This array has metadata that makes
  // it valid only within the IR that implements consumer.  If you are
  // implementing an HLO and want to get its own output buffer, call
  // GetIrArray(hlo, hlo).
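  //
  // For example (hypothetical; `hlo` is the instruction being emitted and
  // `operand` one of its operands):
  //
  //   llvm_ir::IrArray output_array = GetIrArray(*hlo, *hlo);
  //   llvm_ir::IrArray operand_array = GetIrArray(*operand, *hlo);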
  llvm_ir::IrArray GetIrArray(const HloInstruction& inst,
                              const HloInstruction& consumer,
                              const ShapeIndex& shape_index = {}) {
    return bindings_.GetIrArray(inst, consumer, shape_index);
  }
  // A convenient helper for calling HloToIrBindings::GetBasePointer.
  llvm::Value* GetBasePointer(const HloInstruction& inst) const {
    return bindings_.GetBasePointer(inst);
  }
  // A convenient helper for calling BufferAssignment::GetUniqueTopLevelSlice.
  BufferAllocation::Slice GetAllocationSlice(const HloInstruction& hlo) const {
    return ir_emitter_context_->buffer_assignment()
        .GetUniqueTopLevelSlice(&hlo)
        .ConsumeValueOrDie();
  }

  // Emits a single-threaded or multi-threaded loop that computes every element
  // in the result of the given HLO instruction. This produces a series of
  // nested loops (e.g. one loop for each dimension of the `hlo`'s shape). The
  // body of the innermost loop is provided by the body_emitter function.
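  //
  // For example (a sketch, assuming `hlo` is an elementwise f32 add), the
  // body_emitter loads both operands at the current index and adds them:
  //
  //   auto body_emitter =
  //       [&](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
  //     llvm::Value* lhs = GetIrArray(*hlo->operand(0), *hlo)
  //                            .EmitReadArrayElement(index, &ir_builder_);
  //     llvm::Value* rhs = GetIrArray(*hlo->operand(1), *hlo)
  //                            .EmitReadArrayElement(index, &ir_builder_);
  //     return ir_builder_.CreateFAdd(lhs, rhs);
  //   };
  //   TF_RETURN_IF_ERROR(EmitTargetElementLoop(*hlo, body_emitter));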
  virtual Status EmitTargetElementLoop(
      const HloInstruction& hlo,
      const llvm_ir::ElementGenerator& body_emitter) = 0;

  // Emits a call in IR to the given nested computation with the given operands
  // and output. If no IR function has been previously emitted for the
  // computation, also emits such a function.
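  //
  // For example (hypothetical; `add_computation` names an HloComputation that
  // adds two scalars, and the llvm::Value* arguments are pointers to element
  // buffers):
  //
  //   TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
  //       *add_computation, /*operands=*/{lhs_ptr, rhs_ptr},
  //       /*output=*/out_ptr));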
  Status EmitCallToNestedComputation(
      const HloComputation& nested_computation,
      tensorflow::gtl::ArraySlice<llvm::Value*> operands, llvm::Value* output);

  // Emits an atomic operation that implements `nested_computation` in the
  // sequentially consistent memory model. `output_address` and `source_address`
  // are the arguments of the nested computation. For example,
  // atomicAdd(output_address, *source_address).
  Status EmitAtomicOperationForNestedComputation(
      const HloComputation& nested_computation, llvm::Value* output_address,
      llvm::Value* source_address);

  // Returns a NestedComputer that forwards to ComputeNestedElement, for use by
  // GpuElementalIrEmitter.
  GpuElementalIrEmitter::NestedComputer GetNestedComputer() {
    return std::bind(&IrEmitter::ComputeNestedElement, this,
                     std::placeholders::_1, std::placeholders::_2);
  }
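  // Note: the std::bind above is equivalent to the lambda
  //
  //   [this](const HloComputation& computation,
  //          tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements) {
  //     return ComputeNestedElement(computation, parameter_elements);
  //   };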

  IrEmitterContext* ir_emitter_context_;
  llvm::Module* module_;

  // The following fields track the IR emission state. According to LLVM memory
  // management rules, their memory is owned by the module.
  llvm::IRBuilder<> ir_builder_;

  // Mapping from HLO to its underlying LLVM value.
  HloToIrBindings bindings_;

  // Hlo configuration data used during code generation.
  const HloModuleConfig& hlo_module_config_;

 private:
  // Emits a series of nested loops for iterating over an operand array in the
  // dot operation. Loops are constructed in major-to-minor dimension layout
  // order. No loop is emitted for the given reduction_dimension. The function
  // returns an IrArray index for the given operand_array containing the
  // induction variables of the loops. All dimensions of the index are filled
  // except for the reduction dimension. name_suffix is the string to append to
  // the names of LLVM constructs (e.g., basic blocks) constructed by this
  // method.
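  //
  // For example (an illustration, not a contract): for an operand of shape
  // f32[16, 64] with reduction_dimension = 1, this emits a single loop over
  // dimension 0 and returns an index whose dimension-0 entry is that loop's
  // induction variable, leaving the dimension-1 entry for the caller to fill.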
  llvm_ir::IrArray::Index EmitOperandArrayLoopNest(
      const llvm_ir::IrArray& operand_array, int64 reduction_dimension,
      tensorflow::StringPiece name_suffix, llvm_ir::ForLoopNest* loop_nest);

  // A helper method for EmitAtomicOperationForNestedComputation. Certain
  // computations, such as floating-point addition and integer maximum, can be
  // implemented directly with an LLVM atomic instruction. If "computation" is
  // one of this kind, emits code to do that and returns true; otherwise,
  // returns false.
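  //
  // For example (illustrative LLVM IR; the exact types and memory ordering
  // depend on the computation), a scalar integer-add computation can lower to
  //
  //   %old = atomicrmw add i32* %output_address, i32 %source seq_cst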
  bool MaybeEmitDirectAtomicOperation(const HloComputation& computation,
                                      llvm::Value* output_address,
                                      llvm::Value* source_address);

  // A helper method for EmitAtomicOperationForNestedComputation. It implements
  // binary atomic operations using atomicCAS with special handling to support
  // small data types.
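  //
  // The emitted code follows the usual compare-and-swap loop; roughly (a
  // CUDA-style sketch, not the exact IR):
  //
  //   T old = *output_address;
  //   T assumed;
  //   do {
  //     assumed = old;
  //     T updated = computation(assumed, *source_address);
  //     old = atomicCAS(output_address, assumed, updated);
  //   } while (old != assumed);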
  Status EmitAtomicOperationUsingCAS(const HloComputation& computation,
                                     llvm::Value* output_address,
                                     llvm::Value* source_address);

  // Computes `computation` on the given parameter element values and returns
  // the resulting element value; used to implement GetNestedComputer().
  StatusOr<llvm::Value*> ComputeNestedElement(
      const HloComputation& computation,
      tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements);

  // A helper for EmitAtomicOperationForNestedComputation. Emits an LLVM IR
  // function that applies `nested_computation` to operands of type
  // `element_ir_type`, and returns that function.
  StatusOr<llvm::Function*> EmitAtomicFunctionForNestedComputation(
      const HloComputation& nested_computation, llvm::Type* element_ir_type);

  // Maps nested computations to emitted IR functions. This serves as a cache
  // so that IrEmitter does not emit multiple functions for the same
  // HloComputation.
  std::map<const HloComputation*, llvm::Function*> computation_to_ir_function_;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_