/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"

#include <numeric>
#include <vector>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace gpu {

namespace {

bool CanImplementAsCudnnForwardConv(HloInstruction* conv) {
  const ConvolutionDimensionNumbers& dnums =
      conv->convolution_dimension_numbers();
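  // Only convolutions with at most three spatial dimensions can be lowered to
  // cuDNN.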
  if (dnums.input_spatial_dimensions_size() > 3) {
    return false;
  }

  // CuDNN does not accept zero-element arguments.
  if (ShapeUtil::HasZeroElements(conv->operand(0)->shape()) ||
      ShapeUtil::HasZeroElements(conv->operand(1)->shape())) {
    return false;
  }

  if (window_util::HasWindowReversal(conv->window())) {
    return false;
  }
  return true;
}

// Try to match a backward filter pattern that contains "conv".
// Precondition: "conv" is a kConvolution.
std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
    HloInstruction* conv) {
  const auto no_match_result =
      std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
  // Step 1: match the instruction pattern without considering the paddings and
  // dimension numbers just yet. We may need some generic pattern matcher
  // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h
  //
  // Backward filter convolution is implemented in XLA as the forward
  // convolution of padded activations and dilated gradients. Padding on
  // activations and dilation on gradients are specified in the "window" field
  // of the forward convolution.
  //
  //        activations  gradients
  //              \         /
  //               v       v
  //              Convolution
  //                 conv
  CHECK_EQ(HloOpcode::kConvolution, conv->opcode());

  // Step 2: match paddings and dimension numbers of the forward convolution.
  const ConvolutionDimensionNumbers& conv_dnums =
      conv->convolution_dimension_numbers();
  auto input_batch_dim = conv_dnums.input_batch_dimension();
  auto input_feature_dim = conv_dnums.input_feature_dimension();
  auto input_spatial_dims = conv_dnums.input_spatial_dimensions();
  auto kernel_input_feature_dim = conv_dnums.kernel_input_feature_dimension();
  auto kernel_output_feature_dim = conv_dnums.kernel_output_feature_dimension();
  auto kernel_spatial_dims = conv_dnums.kernel_spatial_dimensions();
  auto output_batch_dim = conv_dnums.output_batch_dimension();
  auto output_feature_dim = conv_dnums.output_feature_dimension();
  auto output_spatial_dims = conv_dnums.output_spatial_dimensions();

  for (const WindowDimension& window_dim : conv->window().dimensions()) {
    if (window_dim.stride() != 1) {
      VLOG(1) << "Forward convolution's window "
              << conv->window().ShortDebugString()
              << " should have stride of 1.";
      return no_match_result;
    }
    if (window_dim.base_dilation() != 1) {
      VLOG(1) << "Forward convolution's window "
              << conv->window().ShortDebugString()
              << " should have no base (LHS) dilation.";
      return no_match_result;
    }
    if (window_dim.padding_low() < 0) {
      VLOG(1) << "Padding low should be non-negative.";
      return no_match_result;
    }
    if (window_dim.window_reversal()) {
      VLOG(1) << "Window reversal field not supported";
      return no_match_result;
    }
    // Padding high will be checked in Step 3.
  }
  if (input_batch_dim == output_batch_dim &&
      !window_util::HasWindowDilation(conv->window())) {
    VLOG(1) << conv->ToString()
            << " is a regular forward convolution. No need "
               "to fold it to a backward filter convolution.";
    return no_match_result;
  }

  // Step 3: fuse the matched HLOs into a backward convolution instruction.
  //
  // Compute the window of the backward convolution.
  Window backward_conv_window;
  for (int i = 0; i < input_spatial_dims.size(); ++i) {
    WindowDimension* dim = backward_conv_window.add_dimensions();
    // The window size of the backward convolution equals the output size of the
    // forward convolution.
    int64 filter_size = conv->shape().dimensions(output_spatial_dims[i]);
    dim->set_size(filter_size);
    // The window stride equals the window dilation of the forward convolution.
    dim->set_stride(conv->window().dimensions(i).window_dilation());
    // The window's low padding is the same as the low padding of the
    // activations.
    dim->set_padding_low(conv->window().dimensions(i).padding_low());

    int64 input_size =
        conv->operand(0)->shape().dimensions(input_spatial_dims[i]);
    int64 output_size = conv->window().dimensions(i).size();
    // Compute the range of the amount of valid high padding. We first compute
    // min_padding_high, the amount of padding on the right/bottom to ensure the
    // last patch ends at the border, i.e.,
    //
    //   input_size + dim->padding_low() + min_padding_high
    //     = (output_size - 1) * stride + filter_size
    //
    // Because convolution ignores trailing incomplete windows, any amount of
    // padding high from min_padding_high to min_padding_high+stride-1
    // (max_padding_high) has the same effect.
    int64 padded_input_size = filter_size + (output_size - 1) * dim->stride();
    int64 min_padding_high =
        padded_input_size - input_size - dim->padding_low();
    int64 max_padding_high = min_padding_high + dim->stride() - 1;
    CHECK_GE(dim->padding_low(), 0);
    // In practice, since cuDNN convolution only supports even padding, we make
    // the amount of high padding the same as the amount of low padding as long
    // as it is between min_padding_high and max_padding_high. If it is not in
    // that range, we pick the one that's closest to dim->padding_low() and let
    // PadInsertion canonicalize the resultant backward convolution later.
    // Picking the closest one minimizes the cost of the kPad instruction to be
    // inserted by PadInsertion.
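    // For example, with filter_size = 3, output_size = 2, stride = 2,
    // input_size = 5, and padding_low = 0: padded_input_size = 3 + 1 * 2 = 5,
    // min_padding_high = 0, and max_padding_high = 1. padding_low (0) falls in
    // [0, 1], so the high padding is also set to 0.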
    if (dim->padding_low() >= min_padding_high &&
        dim->padding_low() <= max_padding_high) {
      dim->set_padding_high(dim->padding_low());
    } else {
      if (dim->padding_low() < min_padding_high) {
        dim->set_padding_high(min_padding_high);
      } else {
        dim->set_padding_high(max_padding_high);
      }
    }
    if (dim->padding_high() < 0) {
      LOG(ERROR)
          << "Fusing this pattern to backward filter convolution would cause "
             "negative padding ("
          << dim->padding_high()
          << ") on right/bottom of the weight gradients, which is not "
             "supported by PadInsertion (b/32744257). Falling back to "
             "unfused convolution for instruction: "
          << conv->ToString();
      return no_match_result;
    }
  }

  // Restore the dimension numbers of the backward convolution from the forward
  // convolution. The two activation dimensions are reversed (batch and
  // feature).
  ConvolutionDimensionNumbers backward_conv_dnums;
  backward_conv_dnums.set_input_batch_dimension(input_feature_dim);
  backward_conv_dnums.set_input_feature_dimension(input_batch_dim);
  for (int i = 0; i < input_spatial_dims.size(); ++i) {
    backward_conv_dnums.add_input_spatial_dimensions(input_spatial_dims[i]);
  }
  backward_conv_dnums.set_output_batch_dimension(kernel_input_feature_dim);
  backward_conv_dnums.set_output_feature_dimension(kernel_output_feature_dim);
  for (int i = 0; i < kernel_spatial_dims.size(); ++i) {
    backward_conv_dnums.add_output_spatial_dimensions(kernel_spatial_dims[i]);
  }
  // The dimension numbering of the output of the forward convolution (before
  // transposition) is the same as that of the activations (according to the
  // semantics of kConvolution). The batch dimension of the activations should
  // be treated as the input feature dimension, and the feature dimension should
  // be treated as the output feature.
  backward_conv_dnums.set_kernel_input_feature_dimension(output_batch_dim);
  backward_conv_dnums.set_kernel_output_feature_dimension(output_feature_dim);
  for (int i = 0; i < output_spatial_dims.size(); ++i) {
    backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]);
  }

  return std::make_tuple(true, backward_conv_window, backward_conv_dnums);
}

// Try to match a backward input pattern that contains "conv".
// Precondition: "conv" is a kConvolution.
std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
    HloInstruction* conv) {
  const auto no_match_result =
      std::make_tuple(false, Window(), ConvolutionDimensionNumbers());

  // Match instruction pattern.
  CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
  HloInstruction* reverse_filter = conv->mutable_operand(1);

  // Match the reverse of the filter.
  ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
  const auto& kernel_spatial_dims = dnums.kernel_spatial_dimensions();
  if (reverse_filter->opcode() == HloOpcode::kReverse) {
    if (kernel_spatial_dims.size() != reverse_filter->dimensions().size() ||
        !std::is_permutation(kernel_spatial_dims.begin(),
                             kernel_spatial_dims.end(),
                             reverse_filter->dimensions().begin())) {
      VLOG(1)
          << "Backward input convolution should reverse all kernel dimensions.";
      return no_match_result;
    }
  } else {
    // Possibly 1x1 filter.
    for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) {
      if (conv->window().dimensions(i).size() != 1) {
        VLOG(1) << "The reverse filter is neither a kReverse nor a 1x1 filter: "
                << reverse_filter->ToString();
        return no_match_result;
      }
    }
    if (!window_util::HasBaseDilation(conv->window())) {
      VLOG(1) << conv->ToString()
              << " is a regular forward convolution. No need "
                 "to fold it to a backward input convolution.";
      return no_match_result;
    }
  }

  // Match padding and dilation of the forward convolution.
  for (const WindowDimension& window_dim : conv->window().dimensions()) {
    if (window_dim.stride() != 1) {
      VLOG(1) << "Forward convolution's window "
              << conv->window().ShortDebugString()
              << " should have stride of 1.";
      return no_match_result;
    }
    if (window_dim.window_dilation() != 1) {
      VLOG(1) << "Forward convolution's window "
              << conv->window().ShortDebugString()
              << " should have no window dilation.";
      return no_match_result;
    }
    if (window_dim.window_reversal()) {
      VLOG(1) << "Window reversal field not supported";
      return no_match_result;
    }
  }

  const auto& input_spatial_dims = dnums.input_spatial_dimensions();
  const auto& output_spatial_dims = dnums.output_spatial_dimensions();
  CHECK_EQ(conv->window().dimensions().size(), input_spatial_dims.size());
  CHECK_EQ(output_spatial_dims.size(), input_spatial_dims.size());

  const Window& old_window = conv->window();
  Window new_window = old_window;
  for (size_t i = 0; i < input_spatial_dims.size(); ++i) {
    // Restore backward convolution's padding config from the matched pattern.
    // See the comment in tensorflow/core/kernels/conv_grad_ops.h
    // for how we convert backward input convolution to a variant of forward
    // convolution.
    //
    // The stride of the backward convolution
    // = the base dilation factor of the forward convolution
    auto dim = new_window.mutable_dimensions(i);
    dim->set_stride(old_window.dimensions(i).base_dilation());

    // The low padding = kernel_size - 1 - low padding on the gradients
    // Make sure the low padding is not negative.
    auto kernel_size = old_window.dimensions(i).size();
    auto backward_padding_low =
        kernel_size - 1 - old_window.dimensions(i).padding_low();
    if (backward_padding_low < 0) {
      LOG(ERROR)
          << "The low padding of the backward convolution would be negative ("
          << backward_padding_low
          << "), which isn't supported by PadInsertion for now (b/32744257).";
      return no_match_result;
    }
    dim->set_padding_low(backward_padding_low);

    // Compute the range of the amount of padding on the right/bottom of the
    // activations. XLA's convolution requires all patches to be within the
    // padded base. This gives us flexibility to choose the amount of high
    // padding from a set of values without changing the result of the backward
    // convolution. The minimum amount (min_padding_high) makes the last patch
    // end at the border. The maximum amount (max_padding_high) equals
    // min_padding_high+stride-1 -- max_padding_high+1 would cause the output
    // size to change.
    auto unpadded_input_size = conv->shape().dimensions(output_spatial_dims[i]);
    auto output_size =
        conv->operand(0)->shape().dimensions(input_spatial_dims[i]);
    auto padded_input_size = kernel_size + dim->stride() * (output_size - 1);
    auto total_pad_size = padded_input_size - unpadded_input_size;
    auto min_padding_high = total_pad_size - backward_padding_low;
    auto max_padding_high = min_padding_high + dim->stride() - 1;
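    // For the ab/xy example below (kernel_size = 2, stride = 2,
    // output_size = 2, unpadded_input_size = 3, backward_padding_low = 0):
    // padded_input_size = 4, total_pad_size = 1, min_padding_high = 1, and
    // max_padding_high = 2.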
    if (backward_padding_low >= min_padding_high &&
        backward_padding_low <= max_padding_high) {
      // In the best case (most likely), if backward_padding_low is in the range
      // of the amounts of valid high padding, we choose backward_padding_low
      // because cuDNN supports even padding only.
      dim->set_padding_high(backward_padding_low);
    } else {
      // Otherwise, we choose the amount that's closest to backward_padding_low,
      // and PadInsertion will later insert kSlice instructions to enforce even
      // padding.
      //
      // For example, consider the backward convolution pattern
      //
      //   ab     xy
      //   | pad  | reverse
      //  .a.b    yx
      //     \   /
      //      ABC
      //
      // The amount of low padding on activations (in backward convolution) is
      //   backward_padding_low = kernel_size - 1 - forward_padding_low
      //                        = 2 - 1 - 1 = 0
      //
      // The amount of padding high must be between 1 and 2, in order to make
      // Conv(ABC, xy, stride=2) produce exactly 2 elements (ab). 0 is not in
      // the range of [1,2], so we pick the closest valid amount of padding
      // high, which is 1 in this case. Therefore, we fuse the above pattern to
      //
      //   ABC = BackwardInputConv(ab, xy, stride=2, padding_high=1)
      if (backward_padding_low < min_padding_high) {
        dim->set_padding_high(min_padding_high);
      } else {
        dim->set_padding_high(max_padding_high);
      }
    }
    // PadInsertion doesn't handle backward input convolution with negative
    // padding for now. So fall back to unfused convolution in case of negative
    // padding. For example,
    //   ABCD = Conv(abc, reverse(xy), padding_high=2)
    // could be fused to
    //   ABCD = BackwardInputConv(abc, xy, padding_low=1, padding_high=-1)
    // with positive padding low but negative padding high.
    if (dim->padding_high() < 0) {
      LOG(ERROR) << "Fusing this pattern to backward convolution would cause "
                    "negative padding ("
                 << dim->padding_high()
                 << ") on right/bottom of the activations, which is not "
                    "supported by PadInsertion (b/32744257). Falling back to "
                    "unfused convolution for instruction: "
                 << conv->ToString();
      return no_match_result;
    }
  }

  // Fuse the matched HLOs into a backward convolution instruction.
  //
  // If the reverse is omitted (for 1x1 filters) in the original pattern, we add
  // it back in the fusion instruction so that later passes (such as
  // PadInsertion) can handle such fusion instructions easily.
  if (reverse_filter->opcode() != HloOpcode::kReverse) {
    reverse_filter = reverse_filter->parent()->AddInstruction(
        HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
                                      AsInt64Slice(kernel_spatial_dims)));
    TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter));
  }
  dnums.set_kernel_input_feature_dimension(
      conv->convolution_dimension_numbers().kernel_output_feature_dimension());
  dnums.set_kernel_output_feature_dimension(
      conv->convolution_dimension_numbers().kernel_input_feature_dimension());

  return std::make_tuple(true, new_window, dnums);
}

// Tries to rewrite a single convolution into a call to cudnn.
StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
  CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);

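  // Try the rewrites in order: backward filter, then backward input, then
  // forward convolution. The lambda returns nullptr if none of them apply.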
  HloInstruction* custom_call = [&]() -> HloInstruction* {
    bool match;
    Window window;
    ConvolutionDimensionNumbers dnums;

    std::tie(match, window, dnums) = MatchBackwardFilter(conv);
    if (match) {
      return CreateCudnnConvBackwardFilter(
          conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1),
          window, dnums);
    }

    std::tie(match, window, dnums) = MatchBackwardInput(conv);
    if (match) {
      // Backward input conv subsumes the conv plus the reverse in operand 1.
      HloInstruction* reverse = conv->mutable_operand(1);
      CHECK_EQ(reverse->opcode(), HloOpcode::kReverse);
      HloInstruction* rhs = reverse->mutable_operand(0);

      return CreateCudnnConvBackwardInput(
          conv->shape(), conv->mutable_operand(0), rhs, window, dnums);
    }

    // If all else fails, try a forward convolution.
    if (CanImplementAsCudnnForwardConv(conv)) {
      return CreateCudnnConvForward(conv->shape(), conv->mutable_operand(0),
                                    conv->mutable_operand(1), conv->window(),
                                    conv->convolution_dimension_numbers());
    }

    return nullptr;
  }();

  if (custom_call == nullptr) {
    return false;
  }

  // The CustomCall returns a tuple (conv_result, scratch_memory).  Extract out
  // the conv result and replace `conv` with it.
  TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
      conv,
      HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0)));
  return true;
}

// Rewrites the convolutions in the given computation into calls to cudnn.
// Returns true if it made any changes.
StatusOr<bool> RunOnComputation(HloComputation* computation) {
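  // Collect the convolutions first: the rewrite below replaces instructions,
  // which would otherwise invalidate the iteration over
  // computation->instructions().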
  std::vector<HloInstruction*> convs;
  for (auto* hlo : computation->instructions()) {
    if (hlo->opcode() == HloOpcode::kConvolution) {
      convs.push_back(hlo);
    }
  }

  bool changed = false;
  for (HloInstruction* conv : convs) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv));
    changed |= result;
  }
  return changed;
}
}  // namespace

StatusOr<bool> CudnnConvolutionRewriter::Run(HloModule* module) {
  bool changed = false;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
    changed |= result;
  }
  return changed;
}

}  // namespace gpu
}  // namespace xla