/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"

#include <algorithm>
#include <vector>

#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

namespace xla {
namespace gpu {

namespace {

// Return whether the given shape is a matrix with no padding.
bool IsRank2WithNoPadding(const Shape& shape) {
  return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape);
}

// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
                        const Shape& output_shape) {
  // The inputs and the output must
  // 1) be matrices with no padding and a non-zero number of elements,
  // 2) have an allowed element type.
  bool type_is_allowed = (output_shape.element_type() == F32 ||
                          output_shape.element_type() == F64);
  return type_is_allowed && IsRank2WithNoPadding(lhs_shape) &&
         IsRank2WithNoPadding(rhs_shape) &&
         IsRank2WithNoPadding(output_shape) &&
         !ShapeUtil::HasZeroElements(lhs_shape) &&
         !ShapeUtil::HasZeroElements(rhs_shape);
}
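
// For example, a dot of f32[256,1024] with f32[1024,512] producing
// f32[256,512] satisfies these checks, while a dot with an s32 element type
// or with operands that are not unpadded rank-2 matrices does not.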
}  // namespace

bool ImplementedAsGemm(const HloInstruction& hlo) {
  // We can only do this if the HLO is unnested.
  if (hlo.parent() != hlo.GetModule()->entry_computation()) {
    return false;
  }

  // For certain types of Dot, we can call pre-canned BLAS gemm.
  if (hlo.opcode() == HloOpcode::kDot) {
    const Shape& lhs_shape = hlo.operand(0)->shape();
    const Shape& rhs_shape = hlo.operand(1)->shape();

    // If gemm can accept the operand shapes, use it rather than a custom
    // kernel.
    if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) {
      // The size of the reduction dimension should match. The shape inference
      // guarantees this invariant, so the check here is for programming
      // errors.
      CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0));
      return true;
    }
  }

  if (hlo.opcode() == HloOpcode::kFusion &&
      hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot &&
      hlo.fused_expression_root()->opcode() == HloOpcode::kDot) {
    return true;
  }

  return false;
}

const char* const kCudnnBatchNormForwardInferenceCallTarget =
    "__cudnn$batchNormalizationForwardInference";
const char* const kCudnnBatchNormForwardTrainingCallTarget =
    "__cudnn$batchNormalizationForwardTraining";
const char* const kCudnnBatchNormBackwardCallTarget =
    "__cudnn$batchNormalizationBackward";

bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCudnnBatchNormForwardInferenceCallTarget ||
         target == kCudnnBatchNormForwardTrainingCallTarget ||
         target == kCudnnBatchNormBackwardCallTarget;
}

const char* const kCudnnConvForwardCallTarget = "__cudnn$convForward";
const char* const kCudnnConvBackwardInputCallTarget =
    "__cudnn$convBackwardInput";
const char* const kCudnnConvBackwardFilterCallTarget =
    "__cudnn$convBackwardFilter";

bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCudnnConvForwardCallTarget ||
         target == kCudnnConvBackwardInputCallTarget ||
         target == kCudnnConvBackwardFilterCallTarget;
}

bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
  return ImplementedAsGemm(hlo) || IsCustomCallToDnnBatchNorm(hlo) ||
         IsCustomCallToDnnConvolution(hlo);
}

static HloInstruction* CreateCudnnConv(
    const char* call_target, const Shape& shape, HloInstruction* lhs,
    HloInstruction* rhs, const Window& window,
    const ConvolutionDimensionNumbers& dnums) {
  HloComputation* computation = lhs->parent();

  // This call returns a tuple of (conv_result, scratch_memory), where
  // conv_result is the actual result of the convolution, and scratch_memory is
  // temporary memory used by cudnn.
  //
  // At the moment, we don't know how much scratch memory this conv is going to
  // use, so we put u8[0] in this place.  Later on another pass will choose
  // which conv algorithm to use, and at that point we'll modify the shape of
  // this second tuple element.
  Shape call_shape =
      ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})});

  // Our CustomCall takes three arguments: The conv lhs and rhs, and the cudnn
  // algorithm to use.  It's up to a later pass to choose the algorithm, so to
  // indicate that we haven't yet made a choice, we specify -1 for that arg.
  HloInstruction* negative_one = computation->AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<int64>(-1)));
  HloInstruction* custom_call =
      computation->AddInstruction(HloInstruction::CreateCustomCall(
          call_shape, {lhs, rhs, negative_one}, call_target));
  custom_call->set_window(window);
  custom_call->set_convolution_dimension_numbers(dnums);
  return custom_call;
}
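
// Illustrative sketch of how a caller might unpack the
// (conv_result, scratch_memory) tuple produced by these custom calls, via the
// CreateCudnnConv* wrappers below, using the usual HLO get-tuple-element
// idiom (`computation`, `conv_shape`, `input`, `kernel`, `window`, and
// `dnums` are hypothetical caller-side variables):
//
//   HloInstruction* call = CreateCudnnConvForward(conv_shape, input, kernel,
//                                                 window, dnums);
//   HloInstruction* conv_result = computation->AddInstruction(
//       HloInstruction::CreateGetTupleElement(conv_shape, call, 0));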

HloInstruction* CreateCudnnConvForward(
    const Shape& shape, HloInstruction* input, HloInstruction* kernel,
    const Window& window, const ConvolutionDimensionNumbers& dnums) {
  return CreateCudnnConv(kCudnnConvForwardCallTarget, shape, input, kernel,
                         window, dnums);
}

HloInstruction* CreateCudnnConvBackwardInput(
    const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter,
    const Window& window, const ConvolutionDimensionNumbers& dnums) {
  return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, shape, output,
                         reverse_filter, window, dnums);
}

HloInstruction* CreateCudnnConvBackwardFilter(
    const Shape& shape, HloInstruction* input, HloInstruction* output,
    const Window& window, const ConvolutionDimensionNumbers& dnums) {
  return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, shape, input,
                         output, window, dnums);
}

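// Returns true for a kReduce whose kept (non-reduced) dimensions are
// consecutive in the operand's layout and whose output shape is the operand
// shape restricted to those dimensions. For example, reducing
// f32[128,1024]{1,0} over dimension 0 keeps only dimension 1 and yields
// f32[1024], which qualifies; reducing f32[32,64,8]{2,1,0} over dimension 1
// keeps dimensions {0,2}, which are not consecutive in that layout, so it
// does not.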
bool IsReductionToVector(const HloInstruction& reduce) {
  if (HloOpcode::kReduce != reduce.opcode()) {
    return false;
  }
  const HloInstruction* input = reduce.operand(0);
  std::vector<int64> dims_to_keep;
  for (int64 dim = 0; dim < input->shape().dimensions().size(); ++dim) {
    if (!std::count(reduce.dimensions().begin(), reduce.dimensions().end(),
                    dim)) {
      dims_to_keep.push_back(dim);
    }
  }
  return LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
                                              dims_to_keep) &&
         ShapeUtil::Equal(reduce.shape(), ShapeUtil::FilterDimensions(
                                              [&dims_to_keep](int64 dim) {
                                                return std::count(
                                                    dims_to_keep.begin(),
                                                    dims_to_keep.end(), dim);
                                              },
                                              input->shape()));
}

// This emits a device-side call to
// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see
// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
llvm::Value* EmitPrintf(tensorflow::StringPiece fmt,
                        tensorflow::gtl::ArraySlice<llvm::Value*> arguments,
                        llvm::IRBuilder<>* builder) {
  std::vector<llvm::Type*> argument_types;
  for (auto argument : arguments) {
    argument_types.push_back(argument->getType());
  }
  auto* arguments_type = llvm::StructType::create(argument_types);
  llvm::Value* arguments_ptr = builder->CreateAlloca(arguments_type);
  for (size_t i = 0; i < arguments.size(); ++i) {
    builder->CreateStore(
        arguments[i],
        builder->CreateGEP(arguments_ptr,
                           {builder->getInt64(0), builder->getInt32(i)}));
  }
  return builder->CreateCall(
      builder->GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
          "vprintf",
          llvm::FunctionType::get(builder->getInt32Ty(),
                                  {builder->getInt8Ty()->getPointerTo(),
                                   arguments_type->getPointerTo()},
                                  /*isVarArg=*/false)),
      {builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)),
       arguments_ptr});
}
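
// Illustrative use (with hypothetical `block_id` and `thread_id` i32 values
// already materialized in the IR, and `builder` pointing at the caller's
// llvm::IRBuilder<>): emit a device-side printf while debugging a generated
// kernel.
//
//   EmitPrintf("block=%d thread=%d\n", {block_id, thread_id}, builder);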

llvm::Value* EmitShuffleDown(llvm::Value* value, llvm::Value* offset,
                             llvm::IRBuilder<>* builder) {
  int bit_width = value->getType()->getPrimitiveSizeInBits();

  // Special case for efficiency
  if (value->getType()->isFloatTy() && bit_width == 32) {
    return llvm_ir::EmitCallToIntrinsic(
        llvm::Intrinsic::nvvm_shfl_down_f32,
        {value, offset, builder->getInt32(kWarpSize - 1)}, {}, builder);
  }

  // We must split values wider than 32 bits as the "shfl" instruction operates
  // on 32-bit values.
  int num_segments = CeilOfRatio(bit_width, 32);
  llvm::Value* x = builder->CreateBitCast(
      builder->CreateZExt(
          builder->CreateBitCast(value, builder->getIntNTy(bit_width)),
          builder->getIntNTy(32 * num_segments)),
      llvm::VectorType::get(builder->getInt32Ty(), num_segments));
  for (int i = 0; i < num_segments; ++i) {
    x = builder->CreateInsertElement(
        x,
        llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_shfl_down_i32,
                                     {builder->CreateExtractElement(x, i),
                                      offset, builder->getInt32(kWarpSize - 1)},
                                     {}, builder),
        i);
  }
  return builder->CreateBitCast(
      builder->CreateTrunc(
          builder->CreateBitCast(x, builder->getIntNTy(32 * num_segments)),
          builder->getIntNTy(bit_width)),
      value->getType());
}
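
// EmitShuffleDown splits values wider than 32 bits into 32-bit segments
// because the shfl instruction only moves 32-bit registers between lanes. A
// typical use is an intra-warp tree reduction; a sketch, where `partial` is a
// hypothetical per-thread float partial sum and `builder` is the caller's
// llvm::IRBuilder<>:
//
//   for (int distance = kWarpSize / 2; distance >= 1; distance /= 2) {
//     llvm::Value* peer =
//         EmitShuffleDown(partial, builder->getInt32(distance), builder);
//     partial = builder->CreateFAdd(partial, peer);
//   }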

}  // namespace gpu
}  // namespace xla