/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"

#include <algorithm>
#include <vector>

#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

namespace xla {
namespace gpu {

namespace {

// Return whether the given shape is rank 2 excluding the batch dimensions.
bool IsRank2(const Shape& shape, int64 batch_dimensions_size) {
  return shape.rank() == batch_dimensions_size + 2;
}

// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
                        const Shape& output_shape,
                        int64 batch_dimensions_size) {
  // The inputs and the output must
  // 1) be matrices with no padding and a non-zero number of elements,
  // 2) have an allowed element type.
  PrimitiveType output_primitive_type = output_shape.element_type();
  bool type_is_allowed =
      (output_primitive_type == F16 || output_primitive_type == F32 ||
       output_primitive_type == F64 || output_primitive_type == C64 ||
       output_primitive_type == C128);
  return type_is_allowed && IsRank2(lhs_shape, batch_dimensions_size) &&
         IsRank2(rhs_shape, batch_dimensions_size) &&
         IsRank2(output_shape, batch_dimensions_size) &&
         !ShapeUtil::IsZeroElementArray(lhs_shape) &&
         !ShapeUtil::IsZeroElementArray(rhs_shape);
}
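// For example, with batch_dimensions_size == 1, shapes such as
// lhs = f32[16,256,1024], rhs = f32[16,1024,512], output = f32[16,256,512]
// pass the checks above: each is rank 2 plus one batch dimension, the output
// element type is F32, and neither input has a zero-sized dimension.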

bool DotImplementedAsGemm(const HloInstruction& dot) {
  CHECK_EQ(dot.opcode(), HloOpcode::kDot);
  const Shape& lhs_shape = dot.operand(0)->shape();
  const Shape& rhs_shape = dot.operand(1)->shape();
  const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();

  // If gemm can accept the operand shapes, use it rather than a custom
  // kernel.
  if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape(),
                         dim_numbers.lhs_batch_dimensions_size())) {
    // The sizes of the contracting (reduction) dimensions must match. Shape
    // inference guarantees this invariant, so the CHECK here only guards
    // against programming errors.
    CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
             rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
    return true;
  }
  return false;
}
}  // namespace

bool ImplementedAsGemm(const HloInstruction& hlo) {
  // For certain types of Dot, we can call a pre-canned BLAS gemm.
  if (hlo.opcode() == HloOpcode::kDot) {
    return DotImplementedAsGemm(hlo);
  }

  if (hlo.opcode() == HloOpcode::kFusion &&
      hlo.fusion_kind() == HloInstruction::FusionKind::kOutput &&
      (hlo.fused_expression_root()->opcode() == HloOpcode::kMultiply ||
       hlo.fused_expression_root()->opcode() == HloOpcode::kAdd)) {
    // Try to find the dot inside the output fusion node.
    const HloInstruction* dot = hlo.fused_expression_root()->operand(0);
    if (dot->opcode() != HloOpcode::kDot) {
      dot = hlo.fused_expression_root()->operand(1);
    }
    if (dot->opcode() == HloOpcode::kDot) {
      return DotImplementedAsGemm(*dot);
    }
  }

  return false;
}
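// Sketch of an output fusion recognized above: a fused computation whose root
// is add(dot(a, b), c) or multiply(dot(a, b), c), where the dot may appear as
// either operand of the root and must itself satisfy DotImplementedAsGemm.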

const char* const kCudnnBatchNormForwardInferenceCallTarget =
    "__cudnn$batchNormalizationForwardInference";
const char* const kCudnnBatchNormForwardTrainingCallTarget =
    "__cudnn$batchNormalizationForwardTraining";
const char* const kCudnnBatchNormBackwardCallTarget =
    "__cudnn$batchNormalizationBackward";

bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCudnnBatchNormForwardInferenceCallTarget ||
         target == kCudnnBatchNormForwardTrainingCallTarget ||
         target == kCudnnBatchNormBackwardCallTarget;
}

const char* const kCudnnConvForwardCallTarget = "__cudnn$convForward";
const char* const kCudnnConvBackwardInputCallTarget =
    "__cudnn$convBackwardInput";
const char* const kCudnnConvBackwardFilterCallTarget =
    "__cudnn$convBackwardFilter";
const char* const kCudnnConvBiasActivationForwardCallTarget =
    "__cudnn$convBiasActivationForward";

bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCudnnConvForwardCallTarget ||
         target == kCudnnConvBackwardInputCallTarget ||
         target == kCudnnConvBackwardFilterCallTarget ||
         target == kCudnnConvBiasActivationForwardCallTarget;
}

const char* const kCusolverCholeskyCallTarget = "__cusolver$cholesky";

bool IsCustomCallToCusolver(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCusolverCholeskyCallTarget;
}

bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
  return ImplementedAsGemm(hlo) || IsCustomCallToDnnBatchNorm(hlo) ||
         IsCustomCallToDnnConvolution(hlo);
}
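// Note that custom calls to cuSolver (see IsCustomCallToCusolver above) are
// not folded into ImplementedAsLibraryCall.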

bool IsReductionToVector(const HloInstruction& reduce) {
  if (HloOpcode::kReduce != reduce.opcode()) {
    return false;
  }
  const HloInstruction* input = reduce.operand(0);
  std::vector<int64> dims_to_keep;
  for (int64 dim = 0; dim < input->shape().rank(); ++dim) {
    if (!absl::c_linear_search(reduce.dimensions(), dim)) {
      dims_to_keep.push_back(dim);
    }
  }
  return LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
                                              dims_to_keep) &&
         ShapeUtil::Equal(
             reduce.shape(),
             ShapeUtil::FilterDimensions(
                 [&](int64 dim) { return absl::c_count(dims_to_keep, dim); },
                 input->shape()));
}
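// For example, reducing f32[1024,128]{1,0} along dimension 0 yields f32[128]:
// the kept dimension {1} is trivially consecutive in the layout and the reduce
// output shape equals the filtered input shape, so this counts as a reduction
// to vector.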

// This emits a device-side call to
// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see
// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
llvm::Value* EmitPrintf(absl::string_view fmt,
                        absl::Span<llvm::Value* const> arguments,
                        llvm::IRBuilder<>* builder) {
  std::vector<llvm::Type*> argument_types;
  for (auto argument : arguments) {
    argument_types.push_back(argument->getType());
  }
  auto* arguments_type = llvm::StructType::create(argument_types);
  llvm::Value* arguments_ptr = builder->CreateAlloca(arguments_type);
  for (size_t i = 0; i < arguments.size(); ++i) {
    builder->CreateStore(
        arguments[i],
        builder->CreateGEP(arguments_ptr,
                           {builder->getInt64(0), builder->getInt32(i)}));
  }
  return builder->CreateCall(
      builder->GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
          "vprintf",
          llvm::FunctionType::get(builder->getInt32Ty(),
                                  {builder->getInt8Ty()->getPointerTo(),
                                   arguments_type->getPointerTo()},
                                  /*isVarArg=*/false)),
      {builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)),
       arguments_ptr});
}
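// A typical call site looks like the following sketch, where `x` and `y` stand
// for arbitrary i32 and double SSA values in the current function and `b` for
// the active IRBuilder:
//   EmitPrintf("x=%d y=%f\n", {x, y}, &b);
// The arguments are packed into a stack-allocated struct whose address is
// passed as vprintf's second parameter, matching the device-side vprintf
// calling convention.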

llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
                                     llvm::IRBuilder<>* builder) {
  int bit_width = value->getType()->getPrimitiveSizeInBits();
  llvm::Value* all_warps_mask = builder->getInt32(-1);

  // Special case for efficiency.
  if (value->getType()->isFloatTy() && bit_width == 32) {
    return llvm_ir::EmitCallToIntrinsic(
        llvm::Intrinsic::nvvm_shfl_sync_down_f32,
        {all_warps_mask, value, offset, builder->getInt32(kWarpSize - 1)}, {},
        builder);
  }

  // We must split values wider than 32 bits as the "shfl" instruction operates
  // on 32-bit values.
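  // For example, a 16-bit value is zero-extended to a single 32-bit segment,
  // while a 64-bit value becomes two 32-bit segments that are shuffled
  // independently and reassembled afterwards.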
  int num_segments = CeilOfRatio(bit_width, 32);
  llvm::Value* x = builder->CreateBitCast(
      builder->CreateZExt(
          builder->CreateBitCast(value, builder->getIntNTy(bit_width)),
          builder->getIntNTy(32 * num_segments)),
      llvm::VectorType::get(builder->getInt32Ty(), num_segments));
  for (int i = 0; i < num_segments; ++i) {
    x = builder->CreateInsertElement(
        x,
        llvm_ir::EmitCallToIntrinsic(
            llvm::Intrinsic::nvvm_shfl_sync_down_i32,
            {all_warps_mask, builder->CreateExtractElement(x, i), offset,
             builder->getInt32(kWarpSize - 1)},
            {}, builder),
        i);
  }
  return builder->CreateBitCast(
      builder->CreateTrunc(
          builder->CreateBitCast(x, builder->getIntNTy(32 * num_segments)),
          builder->getIntNTy(bit_width)),
      value->getType());
}
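// EmitFullWarpShuffleDown is typically used to build warp-level reductions:
// the caller repeatedly combines each lane's value with the value shuffled
// down by offsets of 16, 8, 4, 2, and 1 so that lane 0 ends up holding the
// reduction over all 32 lanes.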

StatusOr<CudnnConvKind> GetCudnnConvKind(
    const HloCustomCallInstruction* instr) {
  absl::string_view target = instr->custom_call_target();
  if (target == kCudnnConvForwardCallTarget) {
    return CudnnConvKind::kForward;
  }
  if (target == kCudnnConvBackwardInputCallTarget) {
    return CudnnConvKind::kBackwardInput;
  }
  if (target == kCudnnConvBackwardFilterCallTarget) {
    return CudnnConvKind::kBackwardFilter;
  }
  if (target == kCudnnConvBiasActivationForwardCallTarget) {
    return CudnnConvKind::kForwardActivation;
  }
  return InternalError("Unexpected call target: %s", target);
}

string CudnnConvKindToString(CudnnConvKind kind) {
  switch (kind) {
    case CudnnConvKind::kForward:
      return "forward";
    case CudnnConvKind::kBackwardFilter:
      return "backward_filter";
    case CudnnConvKind::kBackwardInput:
      return "backward_input";
    case CudnnConvKind::kForwardActivation:
      return "forward with activation";
  }
}

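// Returns an i1 value that is true only on the thread with threadIdx.x == 0
// in the block with blockIdx.x == 0, i.e. the first thread of the first block
// in the x dimension.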
llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) {
  return b->CreateAnd(
      b->CreateICmpEQ(
          b->getInt32(0),
          llvm_ir::EmitCallToIntrinsic(
              llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b)),
      b->CreateICmpEQ(
          b->getInt32(0),
          llvm_ir::EmitCallToIntrinsic(
              llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b)));
}

}  // namespace gpu
}  // namespace xla