Home | History | Annotate | Download | only in gpu
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
     18 
     19 #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
     20 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
     21 
     22 namespace xla {
     23 namespace gpu {
     24 
     25 // Emits LLVM IR for an "unnested computation".
     26 //
     27 // An unnested computation is an HloComputation which you run by executing one
     28 // or more kernels for each HloInstruction it contains.  Examples of unnested
     29 // computations:
     30 //
     31 //  - An HloModule's root computation,
     32 //  - The body of an HLO while loop,
     33 //  - The true/false computation of an HLO conditional.
     34 //
     35 // Note the opportunity for confusion -- the while loop's computation is nested
     36 // within the root computation, but it's emitted using IrEmitterUnnested!  Don't
     37 // think about it too hard.
     38 //
     39 // Examples of things that are not unnested computations:
     40 //
     41 //  - The reducer of a kReduce HLO.  This is emited using IrEmitterNested.
     42 //  - The body of a fusion node.  IrEmitterUnenested emits the relevant code
     43 //    within a kernel function using FusedIrEmitter.  (FusedIrEmitter is not
     44 //    really an IrEmitter, but is more an "IR generator generator".)
     45 //
     46 class IrEmitterUnnested : public IrEmitter {
     47  public:
     48   IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
     49                     const HloComputation* hlo_computation,
     50                     IrEmitterContext* ir_emitter_context);
     51   IrEmitterUnnested(const IrEmitterUnnested&) = delete;
     52   IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete;
     53 
     54   // Transfers the ownship of thunk_sequence_ out.
     55   std::unique_ptr<ThunkSequence> ConsumeThunkSequence() {
     56     return std::move(thunk_sequence_);
     57   }
     58 
     59   Status DefaultAction(HloInstruction* hlo) override;
     60 
     61   // IrEmitterUnnested handles the following instructions differently from
     62   // IrEmitter.
     63   Status HandleCopy(HloInstruction* copy) override;
     64   Status HandleConditional(HloInstruction* conditional) override;
     65   Status HandleConvolution(HloInstruction* convolution) override;
     66   Status HandleCustomCall(HloInstruction* custom_call) override;
     67   Status HandleDot(HloInstruction* dot) override;
     68   Status HandleFft(HloInstruction* fft) override;
     69   Status HandleFusion(HloInstruction* fusion) override;
     70   Status HandleGather(HloInstruction* gather) override;
     71   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
     72   Status HandleReduce(HloInstruction* reduce) override;
     73   Status HandleSelectAndScatter(HloInstruction* instruction) override;
     74   Status HandleTuple(HloInstruction* tuple) override;
     75   Status HandleWhile(HloInstruction* xla_while) override;
     76   Status HandleInfeed(HloInstruction* xla_infeed) override;
     77   Status HandleRng(HloInstruction* random) override;
     78   Status HandleSelect(HloInstruction* select) override;
     79 
     80   Status EmitTargetElementLoop(
     81       const HloInstruction& hlo,
     82       const llvm_ir::ElementGenerator& body_emitter) override;
     83 
     84   // Same as `EmitTargetElementLoop`, but in given `thunk` rather than
     85   // `LastThunk()`.
     86   Status EmitTargetElementLoopInThunk(
     87       const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter,
     88       KernelThunk* thunk);
     89 
     90  private:
     91   // Builds the appropriate thunk for the instruction hlo and returns the owning
     92   // pointer to it. The caller needs to make sure `inst` outlives the lifetime
     93   // of the returned Thunk object.
     94   std::unique_ptr<Thunk> BuildThunk(const HloInstruction* hlo);
     95 
     96   // Builds the prototype of the IR kernel for `inst` and adds it to the module.
     97   // This kernel takes as arguments pointers to the given buffer allocations.
     98   llvm::Function* BuildKernelPrototype(
     99       const HloInstruction& inst,
    100       tensorflow::gtl::ArraySlice<const BufferAllocation*> args);
    101 
    102   // EmitColumnReduction and EmitRowReduction emit code for column and row
    103   // reduction of a matrix and/or 3D tensor. Row and column reduction have
    104   // different memory access pattern, so for performance their implementations
    105   // are significantly different.
    106   //
    107   // Emits code that reduces a matrix of shape [height x width] to a vector of
    108   // [width]. Other parameters have the same meaning as those of
    109   // `EmitReductionToVector`. Note that input shape might not be
    110   // [height x width], but can be bitcast to [height x weight] with "height"
    111   // being the major dimension.
    112   Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce,
    113                              const Shape& input_shape,
    114                              const llvm_ir::ElementGenerator& input_gen,
    115                              const llvm_ir::ElementGenerator& init_value_gen,
    116                              HloComputation* reducer);
    117 
    118   // Emits code that reduces a 3D tensor of shape [depth x height x width] to a
    119   // vector of shape [height]. Other parameters have the same meaning as those
    120   // of `EmitReductionToVector`. Note that input shape might not be
    121   // [depth x height x width], but can be bitcast to [depth x height x weight]
    122   // with "depth" being the most major dimension.
    123   Status EmitRowReduction(int64 depth, int64 height, int64 width,
    124                           HloInstruction* reduce, const Shape& input_shape,
    125                           const llvm_ir::ElementGenerator& input_gen,
    126                           const llvm_ir::ElementGenerator& init_value_gen,
    127                           HloComputation* reducer);
    128 
    129   // Emits code that reduces a tensor of arbitrary rank to a scalar.
    130   Status EmitReductionToScalar(HloInstruction* reduce, const Shape& input_shape,
    131                                const llvm_ir::ElementGenerator& input_gen,
    132                                const llvm_ir::ElementGenerator& init_value_gen,
    133                                HloComputation* reducer);
    134 
    135   // Figures out whether `reduce` is a row or column reduction, and which
    136   // dimensions to reduce, and calls either `EmitRowReduction` or
    137   // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the
    138   // input array, which is the operand of the Reduce instruction if unfused or
    139   // of the Fusion instruction if fused. `input_gen` and `init_value_gen`
    140   // generate elements of the input and the initial value. Other parameters mean
    141   // the same as for `HandleReduce`.
    142   //
    143   // Prerequisite: `IsReductionToVector(*reduce)`
    144   Status EmitReductionToVector(
    145       HloInstruction* reduce, const Shape& input_shape,
    146       const llvm_ir::ElementGenerator& input_gen,
    147       const llvm_ir::ElementGenerator& init_value_gen,
    148       tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
    149       HloComputation* reducer);
    150 
    151   // Emits code to initialize buffer of `inst` in given `thunk`.
    152   Status EmitInitializer(const HloInstruction* inst, KernelThunk* thunk);
    153 
    154   // Returns a KernelThunk that invokes the kernel emitted for `inst`. The
    155   // caller needs to make sure `inst` outlives the lifetime of the returned
    156   // Thunk object.
    157   std::unique_ptr<Thunk> BuildKernelThunk(const HloInstruction* inst);
    158 
    159   // Returns a FftThunk that calls cuFFT to implement `inst`.
    160   std::unique_ptr<Thunk> BuildFftThunk(const HloInstruction* inst);
    161 
    162   // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs
    163   // to make sure `inst` outlives the lifetime of the returned Thunk object.
    164   std::unique_ptr<Thunk> BuildGemmThunk(const HloInstruction* inst);
    165 
    166   // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`.
    167   std::unique_ptr<Thunk> BuildHostToDeviceCopyThunk(const HloInstruction* inst);
    168 
    169   // Returns a thunk that calls device-to-device cuMemcpy to implement `inst`.
    170   std::unique_ptr<Thunk> BuildDeviceToDeviceCopyThunk(
    171       const HloInstruction* inst);
    172 
    173   // Returns an InfeedThunk that performs device-to-device memcpy to implement
    174   // `inst`.
    175   std::unique_ptr<Thunk> BuildInfeedThunk(const HloInstruction* inst);
    176 
    177   // Returns a WhileThunk that invokes thunk sequences for 'condition' and
    178   // 'body' sub-computations of while instruction 'hlo'.
    179   std::unique_ptr<Thunk> BuildWhileThunk(const HloInstruction* hlo);
    180 
    181   // Returns a ForThunk which executes 'loop_limit' invocations of a thunk
    182   // sequence from the 'body' sub-computation of the while instruction 'hlo'.
    183   std::unique_ptr<Thunk> BuildForThunk(const HloInstruction* hlo,
    184                                        const int64 loop_limit);
    185 
    186   // Returns a ConditionalThunk that executes the thunk sequence for
    187   // 'true_computation' or 'false_computation' depending on the value of the
    188   // predicate in the given conditional instruction.
    189   std::unique_ptr<Thunk> BuildConditionalThunk(const HloInstruction* hlo);
    190 
    191   Status Postprocess(HloInstruction* hlo) override;
    192 
    193   // Returns the last generated thunk.
    194   Thunk* LastThunk() const { return thunk_sequence_->back().get(); }
    195 
    196   // The thunk sequence this IrEmitter generates for the input computation.
    197   std::unique_ptr<ThunkSequence> thunk_sequence_;
    198 
    199   // The HloComputation that this IrEmitter emits code for.
    200   const HloComputation* hlo_computation_;
    201 };
    202 
    203 }  // namespace gpu
    204 }  // namespace xla
    205 
    206 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
    207