1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ 18 19 #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" 20 #include "tensorflow/compiler/xla/service/gpu/thunk.h" 21 22 namespace xla { 23 namespace gpu { 24 25 // Emits LLVM IR for an "unnested computation". 26 // 27 // An unnested computation is an HloComputation which you run by executing one 28 // or more kernels for each HloInstruction it contains. Examples of unnested 29 // computations: 30 // 31 // - An HloModule's root computation, 32 // - The body of an HLO while loop, 33 // - The true/false computation of an HLO conditional. 34 // 35 // Note the opportunity for confusion -- the while loop's computation is nested 36 // within the root computation, but it's emitted using IrEmitterUnnested! Don't 37 // think about it too hard. 38 // 39 // Examples of things that are not unnested computations: 40 // 41 // - The reducer of a kReduce HLO. This is emited using IrEmitterNested. 42 // - The body of a fusion node. IrEmitterUnenested emits the relevant code 43 // within a kernel function using FusedIrEmitter. (FusedIrEmitter is not 44 // really an IrEmitter, but is more an "IR generator generator".) 45 // 46 class IrEmitterUnnested : public IrEmitter { 47 public: 48 IrEmitterUnnested(const HloModuleConfig& hlo_module_config, 49 const HloComputation* hlo_computation, 50 IrEmitterContext* ir_emitter_context); 51 IrEmitterUnnested(const IrEmitterUnnested&) = delete; 52 IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete; 53 54 // Transfers the ownship of thunk_sequence_ out. 55 std::unique_ptr<ThunkSequence> ConsumeThunkSequence() { 56 return std::move(thunk_sequence_); 57 } 58 59 Status DefaultAction(HloInstruction* hlo) override; 60 61 // IrEmitterUnnested handles the following instructions differently from 62 // IrEmitter. 63 Status HandleCopy(HloInstruction* copy) override; 64 Status HandleConditional(HloInstruction* conditional) override; 65 Status HandleConvolution(HloInstruction* convolution) override; 66 Status HandleCustomCall(HloInstruction* custom_call) override; 67 Status HandleDot(HloInstruction* dot) override; 68 Status HandleFft(HloInstruction* fft) override; 69 Status HandleFusion(HloInstruction* fusion) override; 70 Status HandleGather(HloInstruction* gather) override; 71 Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; 72 Status HandleReduce(HloInstruction* reduce) override; 73 Status HandleSelectAndScatter(HloInstruction* instruction) override; 74 Status HandleTuple(HloInstruction* tuple) override; 75 Status HandleWhile(HloInstruction* xla_while) override; 76 Status HandleInfeed(HloInstruction* xla_infeed) override; 77 Status HandleRng(HloInstruction* random) override; 78 Status HandleSelect(HloInstruction* select) override; 79 80 Status EmitTargetElementLoop( 81 const HloInstruction& hlo, 82 const llvm_ir::ElementGenerator& body_emitter) override; 83 84 // Same as `EmitTargetElementLoop`, but in given `thunk` rather than 85 // `LastThunk()`. 86 Status EmitTargetElementLoopInThunk( 87 const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter, 88 KernelThunk* thunk); 89 90 private: 91 // Builds the appropriate thunk for the instruction hlo and returns the owning 92 // pointer to it. The caller needs to make sure `inst` outlives the lifetime 93 // of the returned Thunk object. 94 std::unique_ptr<Thunk> BuildThunk(const HloInstruction* hlo); 95 96 // Builds the prototype of the IR kernel for `inst` and adds it to the module. 97 // This kernel takes as arguments pointers to the given buffer allocations. 98 llvm::Function* BuildKernelPrototype( 99 const HloInstruction& inst, 100 tensorflow::gtl::ArraySlice<const BufferAllocation*> args); 101 102 // EmitColumnReduction and EmitRowReduction emit code for column and row 103 // reduction of a matrix and/or 3D tensor. Row and column reduction have 104 // different memory access pattern, so for performance their implementations 105 // are significantly different. 106 // 107 // Emits code that reduces a matrix of shape [height x width] to a vector of 108 // [width]. Other parameters have the same meaning as those of 109 // `EmitReductionToVector`. Note that input shape might not be 110 // [height x width], but can be bitcast to [height x weight] with "height" 111 // being the major dimension. 112 Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce, 113 const Shape& input_shape, 114 const llvm_ir::ElementGenerator& input_gen, 115 const llvm_ir::ElementGenerator& init_value_gen, 116 HloComputation* reducer); 117 118 // Emits code that reduces a 3D tensor of shape [depth x height x width] to a 119 // vector of shape [height]. Other parameters have the same meaning as those 120 // of `EmitReductionToVector`. Note that input shape might not be 121 // [depth x height x width], but can be bitcast to [depth x height x weight] 122 // with "depth" being the most major dimension. 123 Status EmitRowReduction(int64 depth, int64 height, int64 width, 124 HloInstruction* reduce, const Shape& input_shape, 125 const llvm_ir::ElementGenerator& input_gen, 126 const llvm_ir::ElementGenerator& init_value_gen, 127 HloComputation* reducer); 128 129 // Emits code that reduces a tensor of arbitrary rank to a scalar. 130 Status EmitReductionToScalar(HloInstruction* reduce, const Shape& input_shape, 131 const llvm_ir::ElementGenerator& input_gen, 132 const llvm_ir::ElementGenerator& init_value_gen, 133 HloComputation* reducer); 134 135 // Figures out whether `reduce` is a row or column reduction, and which 136 // dimensions to reduce, and calls either `EmitRowReduction` or 137 // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the 138 // input array, which is the operand of the Reduce instruction if unfused or 139 // of the Fusion instruction if fused. `input_gen` and `init_value_gen` 140 // generate elements of the input and the initial value. Other parameters mean 141 // the same as for `HandleReduce`. 142 // 143 // Prerequisite: `IsReductionToVector(*reduce)` 144 Status EmitReductionToVector( 145 HloInstruction* reduce, const Shape& input_shape, 146 const llvm_ir::ElementGenerator& input_gen, 147 const llvm_ir::ElementGenerator& init_value_gen, 148 tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce, 149 HloComputation* reducer); 150 151 // Emits code to initialize buffer of `inst` in given `thunk`. 152 Status EmitInitializer(const HloInstruction* inst, KernelThunk* thunk); 153 154 // Returns a KernelThunk that invokes the kernel emitted for `inst`. The 155 // caller needs to make sure `inst` outlives the lifetime of the returned 156 // Thunk object. 157 std::unique_ptr<Thunk> BuildKernelThunk(const HloInstruction* inst); 158 159 // Returns a FftThunk that calls cuFFT to implement `inst`. 160 std::unique_ptr<Thunk> BuildFftThunk(const HloInstruction* inst); 161 162 // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs 163 // to make sure `inst` outlives the lifetime of the returned Thunk object. 164 std::unique_ptr<Thunk> BuildGemmThunk(const HloInstruction* inst); 165 166 // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. 167 std::unique_ptr<Thunk> BuildHostToDeviceCopyThunk(const HloInstruction* inst); 168 169 // Returns a thunk that calls device-to-device cuMemcpy to implement `inst`. 170 std::unique_ptr<Thunk> BuildDeviceToDeviceCopyThunk( 171 const HloInstruction* inst); 172 173 // Returns an InfeedThunk that performs device-to-device memcpy to implement 174 // `inst`. 175 std::unique_ptr<Thunk> BuildInfeedThunk(const HloInstruction* inst); 176 177 // Returns a WhileThunk that invokes thunk sequences for 'condition' and 178 // 'body' sub-computations of while instruction 'hlo'. 179 std::unique_ptr<Thunk> BuildWhileThunk(const HloInstruction* hlo); 180 181 // Returns a ForThunk which executes 'loop_limit' invocations of a thunk 182 // sequence from the 'body' sub-computation of the while instruction 'hlo'. 183 std::unique_ptr<Thunk> BuildForThunk(const HloInstruction* hlo, 184 const int64 loop_limit); 185 186 // Returns a ConditionalThunk that executes the thunk sequence for 187 // 'true_computation' or 'false_computation' depending on the value of the 188 // predicate in the given conditional instruction. 189 std::unique_ptr<Thunk> BuildConditionalThunk(const HloInstruction* hlo); 190 191 Status Postprocess(HloInstruction* hlo) override; 192 193 // Returns the last generated thunk. 194 Thunk* LastThunk() const { return thunk_sequence_->back().get(); } 195 196 // The thunk sequence this IrEmitter generates for the input computation. 197 std::unique_ptr<ThunkSequence> thunk_sequence_; 198 199 // The HloComputation that this IrEmitter emits code for. 200 const HloComputation* hlo_computation_; 201 }; 202 203 } // namespace gpu 204 } // namespace xla 205 206 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ 207