// (Code-viewer navigation residue removed; kept as a comment so the file compiles.)
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h"
     17 
     18 #include "tensorflow/compiler/xla/literal_util.h"
     19 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
     20 #include "tensorflow/compiler/xla/service/shape_inference.h"
     21 #include "tensorflow/compiler/xla/util.h"
     22 #include "tensorflow/compiler/xla/window_util.h"
     23 #include "tensorflow/compiler/xla/xla_data.pb.h"
     24 
     25 namespace xla {
     26 namespace gpu {
     27 
     28 namespace {
     29 bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
     30   CHECK_EQ(conv.custom_call_target(), kCudnnConvForwardCallTarget);
     31   return window_util::HasSymmetricPadding(conv.window()) &&
     32          !window_util::HasNegativePadding(conv.window()) &&
     33          !window_util::HasDilation(conv.window());
     34 }
     35 
     36 // If the (positive and negative) padding on the input operand of a convolution
     37 // can't be folded into a cuDNN convolution libcall (e.g. uneven padding and
     38 // dilation), returns kPad and/or kSlice instructions that explicitly apply the
     39 // padding; otherwise returns the original input operand. When there is both
     40 // positive padding (including dilation) and negative padding, we insert both
     41 // kPad and kSlice.
// Returns an HLO that feeds the convolution in place of `input`, with any
// window padding that cuDNN cannot express made explicit:
//   - positive/uneven padding and base dilation become a kPad, and
//   - negative padding becomes a kSlice.
// If neither applies, `input` is returned unchanged. Both new instructions are
// added to `input`'s computation.
HloInstruction* MaybePaddedAndSlicedInput(
    const Window& conv_window, const ConvolutionDimensionNumbers& conv_dnums,
    HloInstruction* input) {
  HloComputation* computation = input->parent();
  if (!window_util::HasSymmetricPadding(conv_window) ||
      window_util::HasBaseDilation(conv_window)) {
    // If padding is uneven or has dilation, we insert a kPad instruction that
    // applies positive padding and dilation.
    //
    // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of
    // moving all the padding into an explicit pad op, we should keep as much
    // padding inside of cudnn as possible, on the assumption that padding
    // within cudnn is basically free, whereas a kPad's cost increases as the
    // amount of padding increases.
    PaddingConfig padding_config =
        MakeNoPaddingConfig(input->shape().dimensions_size());
    // Map each spatial window dimension i onto the corresponding input
    // dimension via conv_dnums. Only non-negative edge padding goes into the
    // kPad; negative amounts are handled by the kSlice below. Base dilation d
    // is expressed as (d - 1) zeros interleaved between elements.
    for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
      int64 dim = conv_dnums.input_spatial_dimensions(i);
      padding_config.mutable_dimensions(dim)->set_edge_padding_low(
          std::max<int64>(0LL, conv_window.dimensions(i).padding_low()));
      padding_config.mutable_dimensions(dim)->set_edge_padding_high(
          std::max<int64>(0LL, conv_window.dimensions(i).padding_high()));
      padding_config.mutable_dimensions(dim)->set_interior_padding(
          conv_window.dimensions(i).base_dilation() - 1);
    }
    // The pad value is a zero scalar of the input's element type.
    PrimitiveType element_type = input->shape().element_type();
    HloInstruction* padding =
        computation->AddInstruction(HloInstruction::CreateConstant(
            MakeUnique<Literal>(Literal::Zero(element_type))));
    // Replace `input` with the padded version; the slice below (if any) then
    // operates on the padded tensor.
    input = computation->AddInstruction(HloInstruction::CreatePad(
        ShapeInference::InferPadShape(
            /*operand_shape=*/input->shape(),
            /*padding_value_shape=*/ShapeUtil::MakeShape(element_type, {}),
            padding_config)
            .ConsumeValueOrDie(),
        input, padding, padding_config));
  }

  if (window_util::HasNegativePadding(conv_window)) {
    // If the window has negative padding, insert a kSlice that explicitly
    // applies negative padding.
    //
    // For each dimension, initialize the start index to 0 and the limit index
    // to the size of that dimension.
    std::vector<int64> start_indices(input->shape().dimensions_size(), 0);
    std::vector<int64> limit_indices(input->shape().dimensions().begin(),
                                     input->shape().dimensions().end());
    std::vector<int64> strides(input->shape().dimensions_size(), 1);
    for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
      int64 dim = conv_dnums.input_spatial_dimensions(i);
      // If dimension "dim" has negative padding, increase the start index or
      // decrement the limit index by the amount of negative padding.
      start_indices[dim] +=
          std::max<int64>(0LL, -conv_window.dimensions(i).padding_low());
      limit_indices[dim] -=
          std::max<int64>(0LL, -conv_window.dimensions(i).padding_high());
    }

    input = computation->AddInstruction(HloInstruction::CreateSlice(
        ShapeInference::InferSliceShape(input->shape(), start_indices,
                                        limit_indices, strides)
            .ConsumeValueOrDie(),
        input, start_indices, limit_indices, strides));
  }

  return input;
}
    109 
    110 // If the padding on the kernel operand of a convolution can't be folded into a
    111 // cuDNN convolution libcall (e.g. dilation), returns a kPad instruction that
    112 // explicitly applies the padding; otherwise returns the original kernel
    113 // operand.
// Returns an HLO that feeds the convolution in place of `kernel`, with window
// dilation made explicit as interior padding via a kPad (dilation d becomes
// d - 1 interleaved zeros). If the window has no dilation, `kernel` is
// returned unchanged.
HloInstruction* MaybePaddedKernel(const Window& conv_window,
                                  const ConvolutionDimensionNumbers& conv_dnums,
                                  HloInstruction* kernel) {
  if (!window_util::HasWindowDilation(conv_window)) {
    return kernel;
  }

  // Compute the shape and padding config of the pad to be inserted.
  // Start with default (zero) padding for every kernel dimension, then set
  // interior padding on the spatial dimensions only.
  PaddingConfig padding_config;
  for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) {
    padding_config.add_dimensions();
  }
  for (size_t i = 0; i < conv_dnums.kernel_spatial_dimensions().size(); ++i) {
    int64 dim = conv_dnums.kernel_spatial_dimensions(i);
    padding_config.mutable_dimensions(dim)->set_interior_padding(
        conv_window.dimensions(i).window_dilation() - 1);
  }

  // Pad with a zero scalar of the kernel's element type.
  HloComputation* computation = kernel->parent();
  PrimitiveType element_type = kernel->shape().element_type();
  HloInstruction* padding =
      computation->AddInstruction(HloInstruction::CreateConstant(
          MakeUnique<Literal>(Literal::Zero(element_type))));
  return computation->AddInstruction(HloInstruction::CreatePad(
      ShapeInference::InferPadShape(
          /*operand_shape=*/kernel->shape(),
          /*padding_value_shape=*/ShapeUtil::MakeShape(element_type, {}),
          padding_config)
          .ConsumeValueOrDie(),
      kernel, padding, padding_config));
}
    145 }  // namespace
    146 
// Rewrites a forward-convolution custom call whose window cuDNN cannot handle
// (asymmetric/negative padding or dilation) into explicit kPad/kSlice ops on
// its operands plus a convolution with a canonical (pad-free, dilation-free)
// window. Returns true iff the instruction was replaced.
bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) {
  if (IsForwardConvolutionCanonical(*conv)) {
    return false;
  }

  // Insert slices and/or pads between the convolution and its input and/or
  // kernel operand.
  HloInstruction* new_input = MaybePaddedAndSlicedInput(
      conv->window(), conv->convolution_dimension_numbers(),
      conv->mutable_operand(0));
  HloInstruction* new_kernel =
      MaybePaddedKernel(conv->window(), conv->convolution_dimension_numbers(),
                        conv->mutable_operand(1));

  // Remove the padding from convolution's window field. These paddings are
  // made explicit with the inserted pads.
  Window new_conv_window = conv->window();
  for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) {
    WindowDimension* dim = new_conv_window.mutable_dimensions(i);

    // The size of the kernel may have changed so update the Window to match.
    dim->set_size(new_kernel->shape().dimensions(
        conv->convolution_dimension_numbers().kernel_spatial_dimensions(i)));
    dim->set_padding_low(0);
    dim->set_padding_high(0);
    dim->set_base_dilation(1);
    dim->set_window_dilation(1);
  }

  // The conv CustomCall returns a tuple (conv_result, scratch_buffer).  Extract
  // out the shape of conv_result.
  Shape old_conv_shape = conv->shape().tuple_shapes(0);

  VLOG(1) << "Canonicalizing forward conv";
  auto new_conv = CreateCudnnConvForward(old_conv_shape, new_input, new_kernel,
                                         new_conv_window,
                                         conv->convolution_dimension_numbers());
  VLOG(1) << "Replacing:\n  " << conv->ToString() << "\nwith:\n  "
          << new_conv->ToString();
  TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
  return true;
}
    189 
    190 namespace {
    191 void IncreasePaddingLowBy(int64 delta, WindowDimension* window_dim) {
    192   window_dim->set_padding_low(window_dim->padding_low() + delta);
    193 }
    194 
    195 void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) {
    196   window_dim->set_padding_high(window_dim->padding_high() + delta);
    197 }
    198 }  // namespace
    199 
// Canonicalizes a backward-filter convolution with uneven (but non-negative)
// padding by explicitly padding the activations with a kPad and shrinking the
// convolution's padding to the smaller of the two sides. Returns true iff the
// instruction was replaced; returns false for already-symmetric padding and
// for negative padding (which this pass cannot remove — see TODO below).
bool PadInsertion::CanonicalizeBackwardFilterConvolution(
    HloInstruction* backward_conv) {
  CHECK_EQ(backward_conv->custom_call_target(),
           kCudnnConvBackwardFilterCallTarget);
  if (window_util::HasSymmetricPadding(backward_conv->window())) {
    return false;
  }

  // A backward filter convolution with uneven padding can be canonicalized to
  // one with even padding by padding the activations (input) beforehand. For
  // example,
  //   BackwardFilterConv(ABCD, xyz, padding_low=1, padding_high=2)
  // is equivalent to
  //   ABCD0 = Pad(ABCD, padding_high=1)
  //   BackwardFilterConv(ABCD0, xyz, padding_low=pading_high=1)
  // We choose the lesser of padding_low and padding_high as the new padding.
  HloInstruction* input = backward_conv->mutable_operand(0);
  Window new_backward_conv_window = backward_conv->window();
  // input_padding_config is the config of the kPad to be inserted.
  PaddingConfig input_padding_config =
      MakeNoPaddingConfig(ShapeUtil::Rank(input->shape()));
  ConvolutionDimensionNumbers backward_conv_dnums =
      backward_conv->convolution_dimension_numbers();
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64 padding_low = backward_conv->window().dimensions(i).padding_low();
    int64 padding_high = backward_conv->window().dimensions(i).padding_high();
    if (padding_low < 0 || padding_high < 0) {
      // TODO(b/32744257): The following canonicalization wouldn't remove
      // negative padding in a backward convolution, and would therefore cause
      // cuDNN convolution (which doesn't support negative padding) to fail.
      return false;
    }
    // Compute the new, even padding for the backward conv operation.
    int64 new_conv_padding = std::min(padding_low, padding_high);
    int64 dim = backward_conv_dnums.input_spatial_dimensions(i);
    // The kPad absorbs the excess over the new symmetric padding on each side.
    input_padding_config.mutable_dimensions(dim)->set_edge_padding_low(
        padding_low - new_conv_padding);
    input_padding_config.mutable_dimensions(dim)->set_edge_padding_high(
        padding_high - new_conv_padding);

    // Since we move some padding from the backward convolution to the kPad, we
    // need to accordingly reduce the padding amount of the backward convolution
    // and its inner forward convolution.
    auto* new_dim = new_backward_conv_window.mutable_dimensions(i);
    new_dim->set_padding_low(new_conv_padding);
    new_dim->set_padding_high(new_conv_padding);
  }

  // Create a new backward convolution replacing the old one.
  HloComputation* computation = backward_conv->parent();
  HloInstruction* output = backward_conv->mutable_operand(1);
  // Pad the activations with a zero scalar of their element type.
  HloInstruction* padding =
      computation->AddInstruction(HloInstruction::CreateConstant(
          MakeUnique<Literal>(Literal::Zero(input->shape().element_type()))));
  HloInstruction* padded_input =
      computation->AddInstruction(HloInstruction::CreatePad(
          ShapeInference::InferPadShape(input->shape(), padding->shape(),
                                        input_padding_config)
              .ConsumeValueOrDie(),
          input, padding, input_padding_config));

  // The shape of the backward_conv CustomCall is a tuple (conv_result,
  // scratch_buffer).  Extract out the shape of conv_result.
  Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);
  HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter(
      backward_conv_shape, padded_input, output, new_backward_conv_window,
      backward_conv_dnums);

  VLOG(1) << "Canonicalizing backward filter conv";
  VLOG(1) << "Replacing:\n  " << backward_conv->ToString() << "\nwith:\n  "
          << new_backward_conv->ToString();

  TF_CHECK_OK(
      computation->ReplaceInstruction(backward_conv, new_backward_conv));
  return true;
}
    276 
// Canonicalizes a backward-input convolution with uneven (but non-negative)
// padding: the padding on the larger side is reduced to match the smaller
// side, which enlarges the convolution's output, and a kSlice then trims the
// extra elements so the final result keeps the original shape. Returns true
// iff the instruction was replaced; returns false for symmetric padding and
// for negative padding (see TODO below).
bool PadInsertion::CanonicalizeBackwardInputConvolution(
    HloInstruction* backward_conv) {
  if (window_util::HasSymmetricPadding(backward_conv->window())) {
    return false;
  }

  Window new_backward_conv_window = backward_conv->window();
  ConvolutionDimensionNumbers backward_conv_dnums =
      backward_conv->convolution_dimension_numbers();

  // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory).
  // Get the shape of conv_result.
  Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);

  Shape new_backward_conv_shape = backward_conv_shape;
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64 padding_low = backward_conv->window().dimensions(i).padding_low();
    int64 padding_high = backward_conv->window().dimensions(i).padding_high();
    if (padding_low < 0 || padding_high < 0) {
      // TODO(b/32744257): The following canonicalization wouldn't remove
      // negative padding in a backward convolution, and would therefore cause
      // cuDNN convolution (which doesn't support negative padding) to fail.
      return false;
    }
    // If the backward convolution has uneven padding on the activations, we
    // move some padding on the larger end to "internal" padding, so that the
    // backward convolution produces larger activations which get sliced later.
    //
    // For example, suppose we have a non-canonical HLO
    //   [A] = BackwardInputConvolve([a b], [x y z], padding=(low=2,high=1))
    // where the amount of padding low is larger, we can canonicalize it to
    //   [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1))
    //   [A] = Slice([B A])
    // Note the deltas below are negative: the larger side is *reduced* down to
    // the smaller side's padding.
    if (padding_low > padding_high) {
      IncreasePaddingLowBy(padding_high - padding_low,
                           new_backward_conv_window.mutable_dimensions(i));
    } else if (padding_low < padding_high) {
      IncreasePaddingHighBy(padding_low - padding_high,
                            new_backward_conv_window.mutable_dimensions(i));
    }
    // Decreasing the padding by X *increases* the size of our output by X.
    int64 dim = backward_conv_dnums.output_spatial_dimensions(i);
    new_backward_conv_shape.set_dimensions(
        dim, new_backward_conv_shape.dimensions(dim) +
                 std::abs(padding_low - padding_high));
  }

  // Create a new backward convolution replacing the old one.
  HloComputation* computation = backward_conv->parent();
  HloInstruction* output = backward_conv->mutable_operand(0);
  HloInstruction* filter = backward_conv->mutable_operand(1);

  HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput(
      new_backward_conv_shape, output, filter, new_backward_conv_window,
      backward_conv_dnums);

  // The CustomCall created above returns a tuple (conv_result, scratch_memory).
  // Extract out the two elements.
  HloInstruction* new_backward_conv =
      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
          new_backward_conv_shape, new_backward_conv_call, 0));
  HloInstruction* new_backward_conv_scratch =
      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
          new_backward_conv_call->shape().tuple_shapes(1),
          new_backward_conv_call, 1));

  // Slice the new backward convolution.
  //
  // Initialize start_indices and limit_indices as no slicing.
  std::vector<int64> start_indices(new_backward_conv->shape().dimensions_size(),
                                   0LL);
  std::vector<int64> limit_indices(
      new_backward_conv->shape().dimensions().begin(),
      new_backward_conv->shape().dimensions().end());
  std::vector<int64> strides(new_backward_conv->shape().dimensions_size(), 1LL);
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64 padding_low = backward_conv->window().dimensions(i).padding_low();
    int64 padding_high = backward_conv->window().dimensions(i).padding_high();
    int64 dim = backward_conv_dnums.output_spatial_dimensions(i);
    if (padding_low > padding_high) {
      // If the amount of low padding (of the old backward convolution) is
      // larger, we internally pad the low end of the activations and slice
      // internal padding out here.
      start_indices[dim] += padding_low - padding_high;
    } else if (padding_low < padding_high) {
      // If the amount of high padding is larger, we slice out the internal
      // padding on the high end.
      limit_indices[dim] -= padding_high - padding_low;
    }
  }

  // Replace the old backward convolution with the slice.
  // Sanity check: after slicing, the shape must match the original conv
  // result's shape exactly.
  Shape slice_shape =
      ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices,
                                      limit_indices, strides)
          .ConsumeValueOrDie();
  CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape))
      << ShapeUtil::HumanString(slice_shape) << " vs "
      << ShapeUtil::HumanString(backward_conv_shape);

  HloInstruction* slice = computation->AddInstruction(
      HloInstruction::CreateSlice(backward_conv_shape, new_backward_conv,
                                  start_indices, limit_indices, strides));
  // Re-bundle (sliced result, scratch) so the replacement has the same tuple
  // shape as the original custom call.
  HloInstruction* new_tuple = computation->AddInstruction(
      HloInstruction::CreateTuple({slice, new_backward_conv_scratch}));

  VLOG(1) << "Canonicalizing backward input conv";
  VLOG(1) << "Replacing:\n  " << backward_conv->ToString() << "\nwith:\n  "
          << new_tuple->ToString();

  TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple));
  return true;
}
    390 
    391 StatusOr<bool> PadInsertion::Run(HloModule* module) {
    392   bool changed = false;
    393   for (HloInstruction* instruction :
    394        module->entry_computation()->MakeInstructionPostOrder()) {
    395     if (IsCustomCallToDnnConvolution(*instruction)) {
    396       const auto& target = instruction->custom_call_target();
    397       if (target == kCudnnConvForwardCallTarget) {
    398         changed |= CanonicalizeForwardConvolution(instruction);
    399       } else if (target == kCudnnConvBackwardFilterCallTarget) {
    400         changed |= CanonicalizeBackwardFilterConvolution(instruction);
    401       } else if (target == kCudnnConvBackwardInputCallTarget) {
    402         changed |= CanonicalizeBackwardInputConvolution(instruction);
    403       } else {
    404         LOG(FATAL) << "Unknown custom call target for cudnn conv: "
    405                    << instruction->ToString();
    406       }
    407     }
    408   }
    409   return changed;
    410 }
    411 
    412 }  // namespace gpu
    413 }  // namespace xla
    414