/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_

#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h"

namespace xla {
namespace gpu {

// Emits LLVM IR for an "unnested computation".
//
// An unnested computation is an HloComputation which you run by executing one
// or more kernels for each HloInstruction it contains.  Examples of unnested
// computations:
//
//  - An HloModule's root computation,
//  - The body of an HLO while loop,
//  - The true/false computation of an HLO conditional.
//
// Note the opportunity for confusion -- the while loop's computation is nested
// within the root computation, but it's emitted using IrEmitterUnnested!  Don't
// think about it too hard.
//
// Examples of things that are not unnested computations:
//
//  - The reducer of a kReduce HLO.  This is emitted using IrEmitterNested.
//  - The body of a fusion node.  IrEmitterUnnested emits the relevant code
//    within a kernel function using FusedIrEmitter.  (FusedIrEmitter is not
//    really an IrEmitter, but is more an "IR generator generator".)
//
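// As a rough usage sketch (hypothetical driver code, not part of this file;
// it assumes the DfsHloVisitor Accept protocol inherited via IrEmitter):
//
//   IrEmitterUnnested ir_emitter(hlo_module_config, computation,
//                                ir_emitter_context);
//   TF_RETURN_IF_ERROR(computation->Accept(&ir_emitter));
//   std::unique_ptr<ThunkSequence> thunks = ir_emitter.ConsumeThunkSequence();
//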
class IrEmitterUnnested : public IrEmitter {
 public:
  // Parameter block_contains_multi_tiles indicates whether a tile block
  // consists of multiple tiles. If the tile block contains only one tile,
  // there is no need to use an atomic operation to accumulate a local result
  // into the global result when implementing a reduction.
  using TileGenerator =
      std::function<void(const llvm_ir::IrArray::Index& output_tile_origin,
                         absl::Span<llvm::Value* const> output_tile_bounds,
                         bool block_contains_multi_tiles)>;
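
  // A minimal sketch of a TileGenerator callback (hypothetical, for
  // illustration only; the parameters mirror the typedef above):
  //
  //   TileGenerator emit_one_tile =
  //       [&](const llvm_ir::IrArray::Index& output_tile_origin,
  //           absl::Span<llvm::Value* const> output_tile_bounds,
  //           bool block_contains_multi_tiles) {
  //         // Emit IR for the tile whose origin is `output_tile_origin` and
  //         // whose extent is bounded by `output_tile_bounds`.
  //       };
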
  // KernelCodegenInfo records the common information needed to generate code
  // for a kernel that processes tensor elements by blocks. A block of tensor
  // elements may contain one or multiple tiles. The code generators that
  // generate code for tile elements or block prologue/epilogue refer to this
  // class in their prototypes. If the implementation of such a code generator
  // requires other information that is specific to the HLO instruction, the
  // implementation needs to define and use a derived class of this class.
  class KernelCodegenInfo {
   public:
    explicit KernelCodegenInfo(llvm_ir::KernelMappingScheme* mapping_scheme)
        : mapping_scheme_(mapping_scheme),
          tiled_param_info_(nullptr),
          lane_id_(nullptr),
          index_ty_(nullptr) {}
    virtual ~KernelCodegenInfo() {}

    void SetLaneId(llvm::Value* v) { lane_id_ = v; }
    void SetIndexType(llvm::Type* t) { index_ty_ = t; }
    void SetTiledParamInfo(llvm_ir::TiledParameterInfo* tiled_param_info) {
      tiled_param_info_ = tiled_param_info;
    }

    llvm::Value* GetLaneId() const { return lane_id_; }
    llvm_ir::KernelMappingScheme* GetKernelMappingScheme() const {
      return mapping_scheme_;
    }
    llvm_ir::TiledParameterInfo* GetTiledParameterInfo() const {
      return tiled_param_info_;
    }
    llvm::Type* GetIndexType() const { return index_ty_; }

   protected:
    llvm_ir::KernelMappingScheme* mapping_scheme_;
    llvm_ir::TiledParameterInfo* tiled_param_info_;
    llvm::Value* lane_id_;
    llvm::Type* index_ty_;
  };
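
  // A hypothetical sketch of a derived class carrying HLO-specific state
  // (illustrative only; no such class is declared in this header):
  //
  //   class ReductionCodegenInfo : public KernelCodegenInfo {
  //    public:
  //     explicit ReductionCodegenInfo(llvm_ir::KernelMappingScheme* ms)
  //         : KernelCodegenInfo(ms) {}
  //     // ... reduction-specific state, e.g. partial-result buffers ...
  //   };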

  // A function object to prepare for the code generation for a tile block.
  using BlockPrologueGenerator =
      std::function<void(HloInstruction* hlo, KernelCodegenInfo* kernel_info)>;
  // A function object to finalize the code generation for a tile block.
  using BlockEpilogueGenerator =
      std::function<void(HloInstruction* hlo, KernelCodegenInfo* kernel_info)>;
  // A function object to generate code to process one element in a tile.
  //
  // hlo: the instruction for which code is being generated.
  // index: the index for the first output element of the current thread.
  // y_loc: The y coordinate within a tile.
  // x_loc: The x coordinate within a tile.
  // kernel_info: Other information to support the kernel code generation.
  // x_iter_num: When a thread processes N elements in the X dimension,
  //             x_iter_num has a value of 0..N-1 to identify the element
  //             being processed.
  using TileElementGenerator = std::function<void(
      HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
      const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
      llvm::Value* x_loc, int64 x_iter_num)>;

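  // A minimal TileElementGenerator sketch (hypothetical, for illustration
  // only; the parameter order follows the typedef above):
  //
  //   TileElementGenerator emit_elem =
  //       [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
  //           const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
  //           llvm::Value* x_loc, int64 x_iter_num) {
  //         // Emit IR computing the output element of `hlo` at `index`.
  //       };
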
  // KernelCodeGenerator records the code generator objects that generate code
  // for tile elements or tile block prologue/epilogue.
  class KernelCodeGenerator {
   public:
    explicit KernelCodeGenerator(
        TileElementGenerator tile_element_generator,
        BlockPrologueGenerator block_prologue_generator = {},
        BlockEpilogueGenerator block_epilogue_generator = {})
        : tile_element_generator_(std::move(tile_element_generator)),
          block_prologue_generator_(std::move(block_prologue_generator)),
          block_epilogue_generator_(std::move(block_epilogue_generator)) {}

    const TileElementGenerator& GetTileElementGenerator() const {
      return tile_element_generator_;
    }
    const BlockPrologueGenerator& GetBlockPrologueGenerator() const {
      return block_prologue_generator_;
    }
    const BlockEpilogueGenerator& GetBlockEpilogueGenerator() const {
      return block_epilogue_generator_;
    }

   private:
    TileElementGenerator tile_element_generator_;
    BlockPrologueGenerator block_prologue_generator_;
    BlockEpilogueGenerator block_epilogue_generator_;
  };
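
  // Wiring the generators together might look like this (hypothetical
  // sketch; emit_elem, emit_prologue, and emit_epilogue are assumed to be
  // callbacks of the matching types):
  //
  //   KernelCodeGenerator kernel_generator(
  //       /*tile_element_generator=*/std::move(emit_elem),
  //       /*block_prologue_generator=*/std::move(emit_prologue),
  //       /*block_epilogue_generator=*/std::move(emit_epilogue));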

  IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
                    const HloComputation* hlo_computation,
                    IrEmitterContext* ir_emitter_context);
  IrEmitterUnnested(const IrEmitterUnnested&) = delete;
  IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete;

  // Transfers ownership of thunk_sequence_ to the caller.
  std::unique_ptr<ThunkSequence> ConsumeThunkSequence() {
    return std::move(thunk_sequence_);
  }

  Status DefaultAction(HloInstruction* hlo) override;

  // IrEmitterUnnested handles the following instructions differently from
  // IrEmitter.
  Status HandleCopy(HloInstruction* copy) override;
  Status HandleConditional(HloInstruction* conditional) override;
  Status HandleConvolution(HloInstruction* convolution) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;
  Status HandleDot(HloInstruction* dot) override;
  Status HandleFft(HloInstruction* fft) override;
  Status HandleFusion(HloInstruction* fusion) override;
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleReduce(HloInstruction* reduce) override;
  Status HandleSelectAndScatter(HloInstruction* instruction) override;
  Status HandleTuple(HloInstruction* tuple) override;
  Status HandleWhile(HloInstruction* xla_while) override;
  Status HandleInfeed(HloInstruction* xla_infeed) override;
  Status HandleOutfeed(HloInstruction* outfeed) override;
  Status HandleRng(HloInstruction* random) override;
  Status HandleScatter(HloInstruction* scatter) override;
  Status HandleSelect(HloInstruction* select) override;
  Status HandleSort(HloInstruction* sort) override;
  Status HandleTriangularSolve(HloInstruction* hlo) override;
  Status HandleTupleSelect(HloInstruction* tuple_select) override;
  Status HandleAllReduce(HloInstruction* crs) override;
  Status HandleAfterAll(HloInstruction* after_all) override;

  Status EmitTargetElementLoop(
      const HloInstruction& hlo,
      const llvm_ir::ElementGenerator& body_emitter) override;

  // Same as `EmitTargetElementLoop`, but emits code into the given `thunk`
  // rather than into `LastThunk()`.
  Status EmitTargetElementLoopInThunk(
      const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter,
      KernelThunk* thunk);

  // Emits LLVM global variables corresponding to constant instructions.
  Status EmitConstantGlobals();

 private:
  // Adds an owning Thunk object to the thunk sequence.
  void AddThunkToThunkSequence(std::unique_ptr<Thunk> thunk) {
    thunk_sequence_->emplace_back(std::move(thunk));
  }

  // Builds the prototype of the IR kernel for `inst` and adds it to the
  // module. This kernel takes as arguments pointers to the given buffer
  // allocations.
  llvm::Function* BuildKernelPrototype(
      const HloInstruction& inst,
      absl::Span<const BufferAllocation* const> args);

  // Helper for writing extra outputs from inside a reduce kernel.
  Status EmitExtraOutputsForReduce(
      const HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index,
      absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
          extra_output_gens);

  // Generates code for reduction to contiguous dimensions.
  //
  // Prerequisite: `IsReductionToVector(*unnested_hlo)`
  Status EmitReductionToVector(HloInstruction* unnested_hlo);

  // Computes the KernelMappingScheme for the reduce HLO and indicates whether
  // the reduction is a row reduction. For an un-fused reduce op, unnested_hlo
  // and first_reduce are the same instruction. For a kInput fusion,
  // unnested_hlo is the fusion instruction while first_reduce is the first
  // reduce op.
  std::tuple<llvm_ir::KernelMappingScheme, bool>
  ComputeMappingSchemeAndReductionKind(const HloInstruction* unnested_hlo,
                                       const HloInstruction* first_reduce);
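
  // A hypothetical call-site sketch (using std::get rather than assuming
  // KernelMappingScheme is default-constructible):
  //
  //   auto scheme_and_kind =
  //       ComputeMappingSchemeAndReductionKind(unnested_hlo, first_reduce);
  //   bool is_row_reduction = std::get<1>(scheme_and_kind);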

  // Emits code for an in-place scatter, modifying `thunk`'s launch dimensions
  // in the process. `scatter` may be fused; scatter indices are taken from
  // `scatter_indices_gen` and updates from `updates_gen`. The output buffer is
  // expected to already contain the operand values.
  Status EmitScatter(Thunk* thunk, HloInstruction* scatter,
                     const llvm_ir::ElementGenerator& scatter_indices_gen,
                     const llvm_ir::ElementGenerator& updates_gen);

  // Returns true if a 0-2-1 tiling algorithm was used to emit the kernel for
  // the HLO instruction.
  bool CheckAndEmitHloWithTile021(HloInstruction* hlo);
  // Emits a kernel for the HLO instruction using a 0-2-1 tiling algorithm and
  // returns the launch dimensions for the kernel. This is a helper to support
  // the implementation of CheckAndEmitHloWithTile021.
  LaunchDimensions EmitHlo021Tile(HloInstruction* hlo,
                                  absl::Span<const int64> reduced_output_dims,
                                  absl::Span<const int64> tiled_param_ids);
  // Emits a kernel for an unnested HLO instruction.
  LaunchDimensions EmitKernel(HloInstruction* unnested_hlo,
                              absl::Span<const int64> param_ids,
                              const KernelCodeGenerator& kernel_generator,
                              KernelCodegenInfo* kernel_info);
  void EmitBlock(const TileGenerator& emit_one_tile,
                 KernelCodegenInfo* kernel_info, KernelSupportLibrary* ksl,
                 llvm::Type* index_ty);
  // Emits code to process a tensor element in a tile for the given kCopy HLO
  // that performs a 0-2-1 transpose.
  void EmitTileElementForCopy(HloInstruction* hlo,
                              const llvm_ir::IrArray::Index& index,
                              const KernelCodegenInfo* kernel_info,
                              llvm::Value* y_loc, llvm::Value* x_loc,
                              int64 x_iter_num);
  // Emits code to process a tensor element in a tile for the given kLoop
  // fusion HLO containing parameters that are a 0-2-1 transpose of its
  // outputs.
  void EmitTileElementForFusion(HloInstruction* hlo,
                                const llvm_ir::IrArray::Index& index,
                                const KernelCodegenInfo* kernel_info,
                                llvm::Value* y_loc, llvm::Value* x_loc,
                                int64 x_iter_num);
  // Emits code to process a tensor element in a tile for the given input HLO
  // that is either an unnested kReduce or a kInput fusion.
  void EmitTileElementForReduction(HloInstruction* unnested_hlo,
                                   const llvm_ir::IrArray::Index& index,
                                   const KernelCodegenInfo* kernel_info,
                                   llvm::Value* y_loc, llvm::Value* x_loc,
                                   int64 x_iter_num);
  // Prepares for the code generation for a tile block of a reduction kernel.
  void EmitPrologueForReduction(HloInstruction* unnested_hlo,
                                KernelCodegenInfo* kernel_info);
  void EmitPrologueForOneReduction(HloInstruction* unnested_hlo,
                                   HloInstruction* reduce_inst, int reduce_idx,
                                   KernelCodegenInfo* kernel_info,
                                   GpuElementalIrEmitter* elemental_emitter,
                                   ShapeIndex output_shape_index);
  // Wraps up the code generation for a tile block of a reduction kernel.
  void EmitEpilogueForReduction(HloInstruction* unnested_hlo,
                                KernelCodegenInfo* kernel_info);
  // For each reducer, emits the shuffle-down loop to accumulate the partial
  // result to the global result.
  void EmitFullWarpShuffleDownLoopForAllReduces(
      absl::Span<HloComputation* const> reducers,
      absl::Span<llvm::AllocaInst* const> partial_result_addresses);
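
  // The emitted loop follows the standard warp-level tree-reduction pattern;
  // in CUDA-like pseudocode (a sketch of the idea, not the emitted LLVM IR):
  //
  //   for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
  //     partial = reducer(partial, __shfl_down_sync(0xffffffff, partial,
  //                                                 offset));
  //   }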

  // Generates the IrArray for each input of an HLO and returns a vector that
  // contains such IrArrays.
  std::vector<llvm_ir::IrArray> ConstructIrArrayForInputs(
      const HloInstruction& hlo);

  // For each input of the `hlo` instruction, checks its value in
  // `param_buffers` to find out whether the input has a reduced shape. If the
  // input has a reduced shape, constructs the reduced shape for the input and
  // casts the original input IrArray in `param_arrays` to the reduced shape.
  // Returns the total number of inputs.
  int ConstructInputReducedShapeAndCastInputIrArrayToShape(
      const HloInstruction& hlo,
      const std::vector<llvm_ir::IrArray>& param_arrays,
      const std::vector<llvm::Value*>& param_buffers,
      absl::Span<const int64> reduced_output_dims,
      std::vector<Shape>* param_reduced_shapes,
      std::vector<llvm_ir::IrArray>* param_in_reduced_shape_arrays);

  // Returns a KernelThunk that invokes the kernel emitted for `inst`. The
  // caller needs to make sure `inst` outlives the lifetime of the returned
  // Thunk object. The kernel implementation will be unrolled if unroll_factor
  // is greater than one. 'implements_whole_instruction' specifies whether this
  // KernelThunk implements the whole 'inst' HloInstruction. In some cases
  // 'inst' will be implemented by a sequence of Thunks.
  std::unique_ptr<KernelThunk> BuildKernelThunk(
      const HloInstruction* inst, bool implements_whole_instruction,
      int unroll_factor = 1);
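
  // A hypothetical call-site sketch combining members declared in this class:
  //
  //   std::unique_ptr<KernelThunk> kernel_thunk =
  //       BuildKernelThunk(inst, /*implements_whole_instruction=*/true,
  //                        /*unroll_factor=*/1);
  //   AddThunkToThunkSequence(std::move(kernel_thunk));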

  // Returns an FftThunk that calls cuFFT to implement `inst`.
  std::unique_ptr<Thunk> BuildFftThunk(const HloInstruction* inst);

  // Returns a CholeskyThunk that calls cuSolver to implement `inst`.
  std::unique_ptr<Thunk> BuildCholeskyThunk(const HloInstruction* inst);

  // Returns a TriangularSolveThunk that calls cuBLAS to implement `inst`.
  std::unique_ptr<Thunk> BuildTriangularSolveThunk(const HloInstruction* inst);

  // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs
  // to make sure `inst` outlives the lifetime of the returned Thunk object.
  std::unique_ptr<Thunk> BuildGemmThunk(const HloInstruction* inst);

  // Returns a thunk that, given a reduce or select-and-scatter op, initializes
  // its memory to the appropriate initial value.
  StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunk(
      HloInstruction* hlo, const ShapeIndex& index = {});

  // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`.
  std::unique_ptr<Thunk> BuildHostToDeviceCopyThunk(const HloInstruction* inst);

  // Returns a thunk that calls device-to-device cuMemcpy to implement `inst`.
  std::unique_ptr<Thunk> BuildDeviceToDeviceCopyThunk(
      const HloInstruction* inst);

  // Returns an InfeedThunk that performs a host-to-device memcpy to implement
  // `inst`.
  std::unique_ptr<Thunk> BuildInfeedThunk(const HloInstruction* inst);

  // Returns an OutfeedThunk that performs a device-to-host memcpy to implement
  // `inst`.
  std::unique_ptr<Thunk> BuildOutfeedThunk(const HloInstruction* inst);

  // Returns a WhileThunk that invokes thunk sequences for 'condition' and
  // 'body' sub-computations of while instruction 'hlo'.
  std::unique_ptr<Thunk> BuildWhileThunk(const HloInstruction* hlo);

  // Returns a ForThunk which executes 'loop_limit' invocations of a thunk
  // sequence from the 'body' sub-computation of the while instruction 'hlo'.
  std::unique_ptr<Thunk> BuildForThunk(const HloInstruction* hlo,
                                       const int64 loop_limit);

  // Returns a ConditionalThunk which executes the thunk sequence for the
  // 'branch_computation' corresponding to the predicate/branch_index of the
  // given conditional instruction.
  std::unique_ptr<Thunk> BuildConditionalThunk(const HloInstruction* hlo);

  Status Postprocess(HloInstruction* hlo) override;

  // Returns the last generated thunk.
  Thunk* LastThunk() const { return thunk_sequence_->back().get(); }

  // The thunk sequence this IrEmitter generates for the input computation.
  std::unique_ptr<ThunkSequence> thunk_sequence_;

  // The HloComputation that this IrEmitter emits code for.
  const HloComputation* hlo_computation_;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_