Home | History | Annotate | Download | only in gpu
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
     17 #include "tensorflow/compiler/xla/shape_util.h"
     18 #include "tensorflow/compiler/xla/status_macros.h"
     19 #include "tensorflow/compiler/xla/util.h"
     20 
     21 namespace xla {
     22 namespace gpu {
     23 namespace {
     24 
     25 namespace se = ::perftools::gputools;
     26 
     27 using se::DeviceMemory;
     28 using se::DeviceMemoryBase;
     29 using se::Stream;
     30 using se::dnn::AlgorithmConfig;
     31 using se::dnn::BatchDescriptor;
     32 using se::dnn::ConvolutionDescriptor;
     33 using se::dnn::DataLayout;
     34 using se::dnn::DimIndex;
     35 using se::dnn::FilterDescriptor;
     36 using se::dnn::FilterLayout;
     37 using se::dnn::ProfileResult;
     38 
     39 // A StreamExecutor ScratchAllocator that wraps a single XLA allocation,
     40 // returning it (in its entirety) the first time Allocate() is called.
     41 class ScratchBufAllocator : public se::ScratchAllocator {
     42  public:
     43   explicit ScratchBufAllocator(se::DeviceMemoryBase scratch)
     44       : scratch_(scratch) {}
     45 
     46   ~ScratchBufAllocator() override = default;
     47 
     48   int64 GetMemoryLimitInBytes(se::Stream* /*stream*/) override {
     49     return scratch_.size();
     50   }
     51 
     52   se::port::StatusOr<DeviceMemory<uint8>> AllocateBytes(
     53       se::Stream* stream, int64 byte_size) override {
     54     if (allocated_) {
     55       return se::port::InternalError(
     56           "Can't allocate twice from a ScratchBufAllocator.");
     57     }
     58     if (byte_size > scratch_.size()) {
     59       return se::port::InternalError(tensorflow::strings::StrCat(
     60           "Can't allocate ", byte_size,
     61           " bytes from a ScratchBufAllocator of size ", scratch_.size()));
     62     }
     63 
     64     allocated_ = true;
     65     return se::DeviceMemory<uint8>(scratch_);
     66   }
     67 
     68  private:
     69   se::DeviceMemoryBase scratch_;
     70   bool allocated_ = false;
     71 };
     72 
     73 template <typename T>
     74 Status RunCudnnConvolution(
     75     CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
     76     const Shape& output_shape, DeviceMemory<T> input_buf,
     77     DeviceMemory<T> filter_buf, DeviceMemory<T> output_buf,
     78     se::ScratchAllocator* scratch_allocator, const Window& window,
     79     const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm,
     80     Stream* stream, ProfileResult* profile_result /*= nullptr*/) {
     81   VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id();
     82   VLOG(3) << "tensor_ops_enabled: "
     83           << algorithm.algorithm().tensor_ops_enabled();
     84   VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind);
     85   VLOG(3) << "input shape: { " << ShapeUtil::HumanString(input_shape) << " }";
     86   VLOG(3) << "filter shape: { " << ShapeUtil::HumanString(filter_shape) << " }";
     87   VLOG(3) << "Output shape: { " << ShapeUtil::HumanString(output_shape) << " }";
     88   VLOG(3) << "Window: { " << window.ShortDebugString() << " }";
     89   VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }";
     90 
     91   const int num_dimensions = window.dimensions_size();
     92   CHECK_LE(num_dimensions, 3);
     93   // cuDNN does not support 1D convolutions. We therefore express 1D
     94   // convolutions as 2D convolutions where the first spatial dimension is 1.
     95   // This matches the behavior of TF (see definition of conv1d in
     96   // tensorflow/python/ops/nn_ops.py).
     97   const int effective_num_dimensions = std::max(2, num_dimensions);
     98 
     99   if (std::is_same<T, float>::value) {
    100     CHECK_EQ(F32, output_shape.element_type())
    101         << ShapeUtil::HumanString(output_shape);
    102   } else if (std::is_same<T, Eigen::half>::value) {
    103     CHECK_EQ(F16, output_shape.element_type())
    104         << ShapeUtil::HumanString(output_shape);
    105   } else {
    106     LOG(FATAL) << ShapeUtil::HumanString(output_shape);
    107   }
    108 
    109   CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size());
    110   CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size());
    111   CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size());
    112   for (const WindowDimension& dim : window.dimensions()) {
    113     CHECK_EQ(dim.padding_low(), dim.padding_high());
    114   }
    115 
    116   // cuDNN's convolution APIs support the BDYX layout for activations/output and
    117   // the OIYX layout for weights.
    118   BatchDescriptor input_descriptor(effective_num_dimensions);
    119   input_descriptor.set_layout(DataLayout::kBatchDepthYX)
    120       .set_feature_map_count(
    121           input_shape.dimensions(dnums.input_feature_dimension()))
    122       .set_count(input_shape.dimensions(dnums.input_batch_dimension()));
    123   for (int dim = 0; dim < num_dimensions; ++dim) {
    124     // Note that the dimensions are reversed. The same holds below.
    125     input_descriptor.set_spatial_dim(
    126         static_cast<DimIndex>(effective_num_dimensions - dim - 1),
    127         input_shape.dimensions(dnums.input_spatial_dimensions(dim)));
    128   }
    129 
    130   FilterDescriptor filter_descriptor(effective_num_dimensions);
    131   filter_descriptor.set_layout(FilterLayout::kOutputInputYX)
    132       .set_input_feature_map_count(
    133           filter_shape.dimensions(dnums.kernel_input_feature_dimension()))
    134       .set_output_feature_map_count(
    135           filter_shape.dimensions(dnums.kernel_output_feature_dimension()));
    136   for (int dim = 0; dim < num_dimensions; ++dim) {
    137     filter_descriptor.set_spatial_dim(
    138         static_cast<DimIndex>(effective_num_dimensions - dim - 1),
    139         filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim)));
    140   }
    141 
    142   ConvolutionDescriptor convolution_descriptor(effective_num_dimensions);
    143   for (int dim = 0; dim < num_dimensions; ++dim) {
    144     convolution_descriptor
    145         .set_zero_padding(
    146             static_cast<DimIndex>(effective_num_dimensions - dim - 1),
    147             window.dimensions(dim).padding_low())
    148         .set_filter_stride(
    149             static_cast<DimIndex>(effective_num_dimensions - dim - 1),
    150             window.dimensions(dim).stride());
    151   }
    152 
    153   BatchDescriptor output_descriptor(effective_num_dimensions);
    154   output_descriptor.set_layout(DataLayout::kBatchDepthYX)
    155       .set_feature_map_count(
    156           output_shape.dimensions(dnums.output_feature_dimension()))
    157       .set_count(output_shape.dimensions(dnums.output_batch_dimension()));
    158   for (int dim = 0; dim < num_dimensions; ++dim) {
    159     output_descriptor.set_spatial_dim(
    160         static_cast<DimIndex>(effective_num_dimensions - dim - 1),
    161         output_shape.dimensions(dnums.output_spatial_dimensions(dim)));
    162   }
    163 
    164   // Add a singleton dimension in the 1D convolution case.
    165   if (num_dimensions == 1) {
    166     input_descriptor.set_spatial_dim(static_cast<DimIndex>(0), 1);
    167     output_descriptor.set_spatial_dim(static_cast<DimIndex>(0), 1);
    168     filter_descriptor.set_spatial_dim(static_cast<DimIndex>(0), 1);
    169     convolution_descriptor.set_zero_padding(static_cast<DimIndex>(0), 0)
    170         .set_filter_stride(static_cast<DimIndex>(0), 1);
    171   }
    172 
    173   switch (kind) {
    174     case CudnnConvKind::kForward:
    175       stream->ThenConvolveWithAlgorithm(
    176           input_descriptor, input_buf, filter_descriptor, filter_buf,
    177           convolution_descriptor, output_descriptor, &output_buf,
    178           scratch_allocator, algorithm, profile_result);
    179       break;
    180     case CudnnConvKind::kBackwardInput:
    181       stream->ThenConvolveBackwardDataWithAlgorithm(
    182           filter_descriptor, filter_buf, output_descriptor, output_buf,
    183           convolution_descriptor, input_descriptor, &input_buf,
    184           scratch_allocator, algorithm, profile_result);
    185       break;
    186     case CudnnConvKind::kBackwardFilter:
    187       stream->ThenConvolveBackwardFilterWithAlgorithm(
    188           input_descriptor, input_buf, output_descriptor, output_buf,
    189           convolution_descriptor, filter_descriptor, &filter_buf,
    190           scratch_allocator, algorithm, profile_result);
    191       break;
    192   }
    193 
    194   if (!stream->ok()) {
    195     return InternalError(
    196         "Unable to launch convolution with type %s and algorithm (%lld, %lld)",
    197         CudnnConvKindToString(kind).c_str(), algorithm.algorithm().algo_id(),
    198         algorithm.algorithm_no_scratch().algo_id());
    199   }
    200   return Status::OK();
    201 }
    202 
    203 }  // anonymous namespace
    204 
    205 string CudnnConvKindToString(CudnnConvKind kind) {
    206   switch (kind) {
    207     case CudnnConvKind::kForward:
    208       return "forward";
    209     case CudnnConvKind::kBackwardFilter:
    210       return "backward_filter";
    211     case CudnnConvKind::kBackwardInput:
    212       return "backward_input";
    213   }
    214 }
    215 
    216 Status RunCudnnConvolution(
    217     CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
    218     const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf,
    219     perftools::gputools::DeviceMemoryBase filter_buf,
    220     perftools::gputools::DeviceMemoryBase output_buf,
    221     perftools::gputools::DeviceMemoryBase scratch_buf, const Window& window,
    222     const ConvolutionDimensionNumbers& dnums,
    223     perftools::gputools::dnn::AlgorithmConfig algorithm,
    224     perftools::gputools::Stream* stream,
    225     perftools::gputools::dnn::ProfileResult* profile_result) {
    226   ScratchBufAllocator scratch_allocator(scratch_buf);
    227   return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
    228                              input_buf, filter_buf, output_buf,
    229                              &scratch_allocator, window, dnums, algorithm,
    230                              stream, profile_result);
    231 }
    232 
    233 Status RunCudnnConvolution(
    234     CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
    235     const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf,
    236     perftools::gputools::DeviceMemoryBase filter_buf,
    237     perftools::gputools::DeviceMemoryBase output_buf,
    238     perftools::gputools::ScratchAllocator* scratch_allocator,
    239     const Window& window, const ConvolutionDimensionNumbers& dnums,
    240     perftools::gputools::dnn::AlgorithmConfig algorithm,
    241     perftools::gputools::Stream* stream,
    242     perftools::gputools::dnn::ProfileResult* profile_result) {
    243   PrimitiveType output_primitive_type = output_shape.element_type();
    244   CHECK(output_primitive_type == F32 || output_primitive_type == F16)
    245       << ShapeUtil::HumanString(output_shape);
    246   if (output_primitive_type == F32) {
    247     return RunCudnnConvolution(
    248         kind, input_shape, filter_shape, output_shape,
    249         se::DeviceMemory<float>(input_buf), se::DeviceMemory<float>(filter_buf),
    250         se::DeviceMemory<float>(output_buf), scratch_allocator, window, dnums,
    251         algorithm, stream, profile_result);
    252   }
    253   return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
    254                              se::DeviceMemory<Eigen::half>(input_buf),
    255                              se::DeviceMemory<Eigen::half>(filter_buf),
    256                              se::DeviceMemory<Eigen::half>(output_buf),
    257                              scratch_allocator, window, dnums, algorithm,
    258                              stream, profile_result);
    259 }
    260 
    261 }  // namespace gpu
    262 }  // namespace xla
    263