/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"

#include <algorithm>
#include <vector>

#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

namespace xla {
namespace gpu {

namespace {

// Returns whether the given shape is a rank-2 matrix with no padding.
bool IsRank2WithNoPadding(const Shape& shape) {
  return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape);
}

// In a gemm operation where output = lhs * rhs, checks whether the given
// shapes are valid for the operation.
bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
                        const Shape& output_shape) {
  // The inputs and the output must
  // 1) be matrices with no padding and a non-zero number of elements,
  // 2) have an allowed element type.
  bool type_is_allowed = (output_shape.element_type() == F32 ||
                          output_shape.element_type() == F64);
  return type_is_allowed && IsRank2WithNoPadding(lhs_shape) &&
         IsRank2WithNoPadding(rhs_shape) &&
         IsRank2WithNoPadding(output_shape) &&
         !ShapeUtil::HasZeroElements(lhs_shape) &&
         !ShapeUtil::HasZeroElements(rhs_shape);
}
}  // namespace

bool ImplementedAsGemm(const HloInstruction& hlo) {
  // We can only do this if the HLO is unnested.
  if (hlo.parent() != hlo.GetModule()->entry_computation()) {
    return false;
  }

  // For certain types of Dot, we can call a pre-canned BLAS gemm.
  if (hlo.opcode() == HloOpcode::kDot) {
    const Shape& lhs_shape = hlo.operand(0)->shape();
    const Shape& rhs_shape = hlo.operand(1)->shape();

    // If gemm can accept the operand shapes, use it rather than a custom
    // kernel.
    if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) {
      // The sizes of the reduction dimensions should match. Shape inference
      // guarantees this invariant, so the check here is for programming
      // errors.
      CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0));
      return true;
    }
  }

  if (hlo.opcode() == HloOpcode::kFusion &&
      hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot &&
      hlo.fused_expression_root()->opcode() == HloOpcode::kDot) {
    return true;
  }

  return false;
}
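
// Example (illustrative only): an unnested kDot at the entry computation whose
// operands are unpadded f32[256,512] and f32[512,128] matrices (with an
// f32[256,128] result) satisfies AreValidGemmShapes, so ImplementedAsGemm
// returns true and the dot can be emitted as a pre-canned BLAS gemm call. A
// dot with F16 operands, or one nested inside a computation other than the
// entry computation (except the kTransposeDot fusion case above), takes the
// custom-kernel path instead.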

const char* const kCudnnBatchNormForwardInferenceCallTarget =
    "__cudnn$batchNormalizationForwardInference";
const char* const kCudnnBatchNormForwardTrainingCallTarget =
    "__cudnn$batchNormalizationForwardTraining";
const char* const kCudnnBatchNormBackwardCallTarget =
    "__cudnn$batchNormalizationBackward";

bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCudnnBatchNormForwardInferenceCallTarget ||
         target == kCudnnBatchNormForwardTrainingCallTarget ||
         target == kCudnnBatchNormBackwardCallTarget;
}

const char* const kCudnnConvForwardCallTarget = "__cudnn$convForward";
const char* const kCudnnConvBackwardInputCallTarget =
    "__cudnn$convBackwardInput";
const char* const kCudnnConvBackwardFilterCallTarget =
    "__cudnn$convBackwardFilter";

bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCudnnConvForwardCallTarget ||
         target == kCudnnConvBackwardInputCallTarget ||
         target == kCudnnConvBackwardFilterCallTarget;
}

bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
  return ImplementedAsGemm(hlo) || IsCustomCallToDnnBatchNorm(hlo) ||
         IsCustomCallToDnnConvolution(hlo);
}

static HloInstruction* CreateCudnnConv(
    const char* call_target, const Shape& shape, HloInstruction* lhs,
    HloInstruction* rhs, const Window& window,
    const ConvolutionDimensionNumbers& dnums) {
  HloComputation* computation = lhs->parent();

  // This call returns a tuple of (conv_result, scratch_memory), where
  // conv_result is the actual result of the convolution, and scratch_memory is
  // temporary memory used by cudnn.
  //
  // At the moment, we don't know how much scratch memory this conv is going to
  // use, so we put u8[0] in this place. Later on another pass will choose
  // which conv algorithm to use, and at that point we'll modify the shape of
  // this second tuple element.
  Shape call_shape =
      ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})});

  // Our CustomCall takes three arguments: the conv lhs and rhs, and the cudnn
  // algorithm to use. It's up to a later pass to choose the algorithm, so to
  // indicate that we haven't yet made a choice, we specify -1 for that arg.
  HloInstruction* negative_one = computation->AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<int64>(-1)));
  HloInstruction* custom_call =
      computation->AddInstruction(HloInstruction::CreateCustomCall(
          call_shape, {lhs, rhs, negative_one}, call_target));
  custom_call->set_window(window);
  custom_call->set_convolution_dimension_numbers(dnums);
  return custom_call;
}

HloInstruction* CreateCudnnConvForward(
    const Shape& shape, HloInstruction* input, HloInstruction* kernel,
    const Window& window, const ConvolutionDimensionNumbers& dnums) {
  return CreateCudnnConv(kCudnnConvForwardCallTarget, shape, input, kernel,
                         window, dnums);
}

HloInstruction* CreateCudnnConvBackwardInput(
    const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter,
    const Window& window, const ConvolutionDimensionNumbers& dnums) {
  return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, shape, output,
                         reverse_filter, window, dnums);
}

HloInstruction* CreateCudnnConvBackwardFilter(
    const Shape& shape, HloInstruction* input, HloInstruction* output,
    const Window& window, const ConvolutionDimensionNumbers& dnums) {
  return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, shape, input,
                         output, window, dnums);
}
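
// Example (illustrative sketch, not taken from an existing pass): a rewrite
// that replaces an HLO convolution `conv` with the forward cudnn custom call
// might look roughly like this, remembering that the custom call returns a
// (result, scratch) tuple rather than the convolution result itself:
//
//   HloInstruction* call = CreateCudnnConvForward(
//       conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1),
//       conv->window(), conv->convolution_dimension_numbers());
//   TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
//       conv, HloInstruction::CreateGetTupleElement(conv->shape(), call, 0)));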

bool IsReductionToVector(const HloInstruction& reduce) {
  if (HloOpcode::kReduce != reduce.opcode()) {
    return false;
  }
  const HloInstruction* input = reduce.operand(0);
  std::vector<int64> dims_to_keep;
  for (int64 dim = 0; dim < input->shape().dimensions().size(); ++dim) {
    if (!std::count(reduce.dimensions().begin(), reduce.dimensions().end(),
                    dim)) {
      dims_to_keep.push_back(dim);
    }
  }
  return LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
                                              dims_to_keep) &&
         ShapeUtil::Equal(reduce.shape(), ShapeUtil::FilterDimensions(
                                              [&dims_to_keep](int64 dim) {
                                                return std::count(
                                                    dims_to_keep.begin(),
                                                    dims_to_keep.end(), dim);
                                              },
                                              input->shape()));
}

// This emits a device-side call to
// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see
// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
llvm::Value* EmitPrintf(tensorflow::StringPiece fmt,
                        tensorflow::gtl::ArraySlice<llvm::Value*> arguments,
                        llvm::IRBuilder<>* builder) {
  std::vector<llvm::Type*> argument_types;
  for (auto argument : arguments) {
    argument_types.push_back(argument->getType());
  }
  auto* arguments_type = llvm::StructType::create(argument_types);
  llvm::Value* arguments_ptr = builder->CreateAlloca(arguments_type);
  for (size_t i = 0; i < arguments.size(); ++i) {
    builder->CreateStore(
        arguments[i],
        builder->CreateGEP(arguments_ptr,
                           {builder->getInt64(0), builder->getInt32(i)}));
  }
  return builder->CreateCall(
      builder->GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
          "vprintf",
          llvm::FunctionType::get(builder->getInt32Ty(),
                                  {builder->getInt8Ty()->getPointerTo(),
                                   arguments_type->getPointerTo()},
                                  /*isVarArg=*/false)),
      {builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)),
       arguments_ptr});
}
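
// Example (illustrative only): emitted device code could print a 32-bit loop
// index for debugging with something along the lines of
//
//   EmitPrintf("tile_index=%d\n", {index_value}, &b);
//
// where `index_value` is an i32 llvm::Value* and `b` is the current IRBuilder.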

// Emits a warp "shuffle down": each lane receives `value` from the lane whose
// ID is `offset` greater than its own.
llvm::Value* EmitShuffleDown(llvm::Value* value, llvm::Value* offset,
                             llvm::IRBuilder<>* builder) {
  int bit_width = value->getType()->getPrimitiveSizeInBits();

  // Special case for efficiency: a 32-bit float can be shuffled directly.
  if (value->getType()->isFloatTy() && bit_width == 32) {
    return llvm_ir::EmitCallToIntrinsic(
        llvm::Intrinsic::nvvm_shfl_down_f32,
        {value, offset, builder->getInt32(kWarpSize - 1)}, {}, builder);
  }

  // We must split values wider than 32 bits, as the "shfl" instruction
  // operates on 32-bit values.
  int num_segments = CeilOfRatio(bit_width, 32);
  llvm::Value* x = builder->CreateBitCast(
      builder->CreateZExt(
          builder->CreateBitCast(value, builder->getIntNTy(bit_width)),
          builder->getIntNTy(32 * num_segments)),
      llvm::VectorType::get(builder->getInt32Ty(), num_segments));
  for (int i = 0; i < num_segments; ++i) {
    x = builder->CreateInsertElement(
        x,
        llvm_ir::EmitCallToIntrinsic(
            llvm::Intrinsic::nvvm_shfl_down_i32,
            {builder->CreateExtractElement(x, i), offset,
             builder->getInt32(kWarpSize - 1)},
            {}, builder),
        i);
  }
  return builder->CreateBitCast(
      builder->CreateTrunc(
          builder->CreateBitCast(x, builder->getIntNTy(32 * num_segments)),
          builder->getIntNTy(bit_width)),
      value->getType());
}

}  // namespace gpu
}  // namespace xla