1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" 17 18 #include <algorithm> 19 #include <vector> 20 21 #include "llvm/IR/Module.h" 22 #include "tensorflow/compiler/xla/layout_util.h" 23 #include "tensorflow/compiler/xla/service/hlo_computation.h" 24 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 25 #include "tensorflow/compiler/xla/service/hlo_module.h" 26 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 27 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" 28 #include "tensorflow/compiler/xla/shape_util.h" 29 #include "tensorflow/compiler/xla/util.h" 30 #include "tensorflow/compiler/xla/window_util.h" 31 #include "tensorflow/compiler/xla/xla_data.pb.h" 32 #include "tensorflow/core/platform/logging.h" 33 #include "tensorflow/core/platform/protobuf.h" 34 35 namespace xla { 36 namespace gpu { 37 38 namespace { 39 40 // Return whether the given shape is rank 2 excluding the batch dimensions. 41 bool IsRank2(const Shape& shape, int64 batch_dimensions_size) { 42 return shape.rank() == batch_dimensions_size + 2; 43 } 44 45 // In a gemm operation where output = lhs * rhs, check whether the given shapes 46 // are valid for the operation. 47 bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, 48 const Shape& output_shape, 49 int64 batch_dimensions_size) { 50 // The inputs and the output must 51 // 1) be matrices with no padding and a non-zero number of elements, 52 // 2) have an allowed element type. 53 PrimitiveType output_primitive_type = output_shape.element_type(); 54 bool type_is_allowed = 55 (output_primitive_type == F16 || output_primitive_type == F32 || 56 output_primitive_type == F64 || output_primitive_type == C64 || 57 output_primitive_type == C128); 58 return type_is_allowed && IsRank2(lhs_shape, batch_dimensions_size) && 59 IsRank2(rhs_shape, batch_dimensions_size) && 60 IsRank2(output_shape, batch_dimensions_size) && 61 !ShapeUtil::IsZeroElementArray(lhs_shape) && 62 !ShapeUtil::IsZeroElementArray(rhs_shape); 63 } 64 65 bool DotImplementedAsGemm(const HloInstruction& dot) { 66 CHECK_EQ(dot.opcode(), HloOpcode::kDot); 67 const Shape& lhs_shape = dot.operand(0)->shape(); 68 const Shape& rhs_shape = dot.operand(1)->shape(); 69 const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); 70 71 // If gemm can accept the operand shapes, use it rather than a custom 72 // kernel. 73 if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape(), 74 dim_numbers.lhs_batch_dimensions_size())) { 75 // The size of the reduction dimension should match. The shape inference 76 // guarantees this invariant, so the check here is for programming 77 // errors. 78 CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), 79 rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); 80 return true; 81 } 82 return false; 83 } 84 } // namespace 85 86 bool ImplementedAsGemm(const HloInstruction& hlo) { 87 // For certain types of Dot, we can call pre-canned BLAS gemm. 88 if (hlo.opcode() == HloOpcode::kDot) { 89 return DotImplementedAsGemm(hlo); 90 } 91 92 if (hlo.opcode() == HloOpcode::kFusion && 93 hlo.fusion_kind() == HloInstruction::FusionKind::kOutput && 94 (hlo.fused_expression_root()->opcode() == HloOpcode::kMultiply || 95 hlo.fused_expression_root()->opcode() == HloOpcode::kAdd)) { 96 // Try to find the dot inside the output fusion node. 97 const HloInstruction* dot = hlo.fused_expression_root()->operand(0); 98 if (dot->opcode() != HloOpcode::kDot) { 99 dot = hlo.fused_expression_root()->operand(1); 100 } 101 if (dot->opcode() == HloOpcode::kDot) { 102 return DotImplementedAsGemm(*dot); 103 } 104 } 105 106 return false; 107 } 108 109 const char* const kCudnnBatchNormForwardInferenceCallTarget = 110 "__cudnn$batchNormalizationForwardInference"; 111 const char* const kCudnnBatchNormForwardTrainingCallTarget = 112 "__cudnn$batchNormalizationForwardTraining"; 113 const char* const kCudnnBatchNormBackwardCallTarget = 114 "__cudnn$batchNormalizationBackward"; 115 116 bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo) { 117 if (hlo.opcode() != HloOpcode::kCustomCall) { 118 return false; 119 } 120 const auto& target = hlo.custom_call_target(); 121 return target == kCudnnBatchNormForwardInferenceCallTarget || 122 target == kCudnnBatchNormForwardTrainingCallTarget || 123 target == kCudnnBatchNormBackwardCallTarget; 124 } 125 126 const char* const kCudnnConvForwardCallTarget = "__cudnn$convForward"; 127 const char* const kCudnnConvBackwardInputCallTarget = 128 "__cudnn$convBackwardInput"; 129 const char* const kCudnnConvBackwardFilterCallTarget = 130 "__cudnn$convBackwardFilter"; 131 const char* const kCudnnConvBiasActivationForwardCallTarget = 132 "__cudnn$convBiasActivationForward"; 133 134 bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) { 135 if (hlo.opcode() != HloOpcode::kCustomCall) { 136 return false; 137 } 138 const auto& target = hlo.custom_call_target(); 139 return target == kCudnnConvForwardCallTarget || 140 target == kCudnnConvBackwardInputCallTarget || 141 target == kCudnnConvBackwardFilterCallTarget || 142 target == kCudnnConvBiasActivationForwardCallTarget; 143 } 144 145 const char* const kCusolverCholeskyCallTarget = "__cusolver$cholesky"; 146 147 bool IsCustomCallToCusolver(const HloInstruction& hlo) { 148 if (hlo.opcode() != HloOpcode::kCustomCall) { 149 return false; 150 } 151 const auto& target = hlo.custom_call_target(); 152 return target == kCusolverCholeskyCallTarget; 153 } 154 155 bool ImplementedAsLibraryCall(const HloInstruction& hlo) { 156 return ImplementedAsGemm(hlo) || IsCustomCallToDnnBatchNorm(hlo) || 157 IsCustomCallToDnnConvolution(hlo); 158 } 159 160 bool IsReductionToVector(const HloInstruction& reduce) { 161 if (HloOpcode::kReduce != reduce.opcode()) { 162 return false; 163 } 164 const HloInstruction* input = reduce.operand(0); 165 std::vector<int64> dims_to_keep; 166 for (int64 dim = 0; dim < input->shape().dimensions().size(); ++dim) { 167 if (!absl::c_linear_search(reduce.dimensions(), dim)) { 168 dims_to_keep.push_back(dim); 169 } 170 } 171 return LayoutUtil::AreDimensionsConsecutive(input->shape().layout(), 172 dims_to_keep) && 173 ShapeUtil::Equal( 174 reduce.shape(), 175 ShapeUtil::FilterDimensions( 176 [&](int64 dim) { return absl::c_count(dims_to_keep, dim); }, 177 input->shape())); 178 } 179 180 // This emits a device-side call to 181 // "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see 182 // http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls 183 llvm::Value* EmitPrintf(absl::string_view fmt, 184 absl::Span<llvm::Value* const> arguments, 185 llvm::IRBuilder<>* builder) { 186 std::vector<llvm::Type*> argument_types; 187 for (auto argument : arguments) { 188 argument_types.push_back(argument->getType()); 189 } 190 auto* arguments_type = llvm::StructType::create(argument_types); 191 llvm::Value* arguments_ptr = builder->CreateAlloca(arguments_type); 192 for (size_t i = 0; i < arguments.size(); ++i) { 193 builder->CreateStore( 194 arguments[i], 195 builder->CreateGEP(arguments_ptr, 196 {builder->getInt64(0), builder->getInt32(i)})); 197 } 198 return builder->CreateCall( 199 builder->GetInsertBlock()->getParent()->getParent()->getOrInsertFunction( 200 "vprintf", 201 llvm::FunctionType::get(builder->getInt32Ty(), 202 {builder->getInt8Ty()->getPointerTo(), 203 arguments_type->getPointerTo()}, 204 /*isVarArg=*/false)), 205 {builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)), 206 arguments_ptr}); 207 } 208 209 llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, 210 llvm::IRBuilder<>* builder) { 211 int bit_width = value->getType()->getPrimitiveSizeInBits(); 212 llvm::Value* all_warps_mask = builder->getInt32(-1); 213 214 // Special case for efficiency 215 if (value->getType()->isFloatTy() && bit_width == 32) { 216 return llvm_ir::EmitCallToIntrinsic( 217 llvm::Intrinsic::nvvm_shfl_sync_down_f32, 218 {all_warps_mask, value, offset, builder->getInt32(kWarpSize - 1)}, {}, 219 builder); 220 } 221 222 // We must split values wider than 32 bits as the "shfl" instruction operates 223 // on 32-bit values. 224 int num_segments = CeilOfRatio(bit_width, 32); 225 llvm::Value* x = builder->CreateBitCast( 226 builder->CreateZExt( 227 builder->CreateBitCast(value, builder->getIntNTy(bit_width)), 228 builder->getIntNTy(32 * num_segments)), 229 llvm::VectorType::get(builder->getInt32Ty(), num_segments)); 230 for (int i = 0; i < num_segments; ++i) { 231 x = builder->CreateInsertElement( 232 x, 233 llvm_ir::EmitCallToIntrinsic( 234 llvm::Intrinsic::nvvm_shfl_sync_down_i32, 235 {all_warps_mask, builder->CreateExtractElement(x, i), offset, 236 builder->getInt32(kWarpSize - 1)}, 237 {}, builder), 238 i); 239 } 240 return builder->CreateBitCast( 241 builder->CreateTrunc( 242 builder->CreateBitCast(x, builder->getIntNTy(32 * num_segments)), 243 builder->getIntNTy(bit_width)), 244 value->getType()); 245 } 246 247 StatusOr<CudnnConvKind> GetCudnnConvKind( 248 const HloCustomCallInstruction* instr) { 249 absl::string_view target = instr->custom_call_target(); 250 if (target == kCudnnConvForwardCallTarget) { 251 return CudnnConvKind::kForward; 252 } 253 if (target == kCudnnConvBackwardInputCallTarget) { 254 return CudnnConvKind::kBackwardInput; 255 } 256 if (target == kCudnnConvBackwardFilterCallTarget) { 257 return CudnnConvKind::kBackwardFilter; 258 } 259 if (target == kCudnnConvBiasActivationForwardCallTarget) { 260 return CudnnConvKind::kForwardActivation; 261 } 262 return InternalError("Unexpected call target: %s", target); 263 } 264 265 string CudnnConvKindToString(CudnnConvKind kind) { 266 switch (kind) { 267 case CudnnConvKind::kForward: 268 return "forward"; 269 case CudnnConvKind::kBackwardFilter: 270 return "backward_filter"; 271 case CudnnConvKind::kBackwardInput: 272 return "backward_input"; 273 case CudnnConvKind::kForwardActivation: 274 return "forward with activation"; 275 } 276 } 277 278 llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) { 279 return b->CreateAnd( 280 b->CreateICmpEQ( 281 b->getInt32(0), 282 llvm_ir::EmitCallToIntrinsic( 283 llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b)), 284 b->CreateICmpEQ( 285 b->getInt32(0), 286 llvm_ir::EmitCallToIntrinsic( 287 llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b))); 288 } 289 290 } // namespace gpu 291 } // namespace xla 292