Home | History | Annotate | Download | only in gpu
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"
     17 
     18 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
     19 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     20 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     21 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
     22 #include "tensorflow/compiler/xla/service/hlo_module.h"
     23 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     24 #include "tensorflow/compiler/xla/service/shape_inference.h"
     25 #include "tensorflow/compiler/xla/test.h"
     26 #include "tensorflow/compiler/xla/test_helpers.h"
     27 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     28 #include "tensorflow/core/platform/test.h"
     29 
     30 namespace xla {
     31 namespace gpu {
     32 namespace {
     33 
     34 namespace op = xla::testing::opcode_matchers;
     35 
     36 class CudnnConvolutionRewriterTest : public HloTestBase {
     37  public:
     38   CudnnConvolutionRewriterTest() {
     39     for (int i = 0; i < 2; ++i) {
     40       WindowDimension* window_dim = default_conv_window_.add_dimensions();
     41       window_dim->set_size(1);
     42       window_dim->set_stride(1);
     43       window_dim->set_padding_low(0);
     44       window_dim->set_padding_high(0);
     45       window_dim->set_window_dilation(1);
     46       window_dim->set_base_dilation(1);
     47     }
     48     // TF data shapes are by default in the NHWC order, and filter shape is by
     49     // default in HWIO order. For backward filter convolution, we need to swap
     50     // the batch and feature dimension in the activations, and treat the batch
     51     // dimension in gradients as the input feature dimension in the filter.
     52     //
     53     // TODO(jingyue): Add more tests on NCHW input order, which TF also
     54     // supports.
     55     tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3);
     56     tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0);
     57     tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1);
     58     tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(2);
     59     tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0);
     60     tf_default_dnums_for_backward_filter_.set_kernel_output_feature_dimension(
     61         3);
     62     tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1);
     63     tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2);
     64     tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(0);
     65     tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1);
     66     tf_default_dnums_for_backward_filter_.set_output_batch_dimension(2);
     67     tf_default_dnums_for_backward_filter_.set_output_feature_dimension(3);
     68 
     69     tf_default_dnums_for_backward_input_.set_input_batch_dimension(0);
     70     tf_default_dnums_for_backward_input_.set_output_batch_dimension(0);
     71     tf_default_dnums_for_backward_input_.set_input_feature_dimension(3);
     72     tf_default_dnums_for_backward_input_.set_output_feature_dimension(3);
     73     tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(1);
     74     tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(1);
     75     tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(2);
     76     tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(2);
     77     tf_default_dnums_for_backward_input_.set_kernel_input_feature_dimension(3);
     78     tf_default_dnums_for_backward_input_.set_kernel_output_feature_dimension(2);
     79     tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(0);
     80     tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(1);
     81   }
     82 
     83  protected:
     84   bool RunPass(HloModule* module) {
     85     return CudnnConvolutionRewriter().Run(module).ValueOrDie();
     86   }
     87 
     88   // A convolution window with stride 1 and zero padding. The size fields are
     89   // not set.
     90   Window default_conv_window_;
     91   ConvolutionDimensionNumbers tf_default_dnums_for_backward_filter_;
     92   ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_;
     93 };
     94 
     95 TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) {
     96   HloComputation::Builder builder(TestName());
     97   HloInstruction* activations =
     98       builder.AddInstruction(HloInstruction::CreateParameter(
     99           0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations"));
    100   HloInstruction* gradients =
    101       builder.AddInstruction(HloInstruction::CreateParameter(
    102           1, ShapeUtil::MakeShape(F32, {1, 1, 2, 1}), "gradients"));
    103   Window conv_window = default_conv_window_;
    104   conv_window.mutable_dimensions(1)->set_size(2);
    105   conv_window.mutable_dimensions(1)->set_window_dilation(2);
    106   builder.AddInstruction(HloInstruction::CreateConvolve(
    107       ShapeInference::InferConvolveShape(activations->shape(),
    108                                          gradients->shape(), conv_window,
    109                                          tf_default_dnums_for_backward_filter_)
    110           .ConsumeValueOrDie(),
    111       activations, gradients, conv_window,
    112       tf_default_dnums_for_backward_filter_));
    113 
    114   auto module = CreateNewModule();
    115   HloComputation* entry_computation =
    116       module->AddEntryComputation(builder.Build());
    117   EXPECT_TRUE(RunPass(module.get()));
    118   EXPECT_THAT(entry_computation->root_instruction(),
    119               op::GetTupleElement(
    120                   op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
    121 }
    122 
    123 TEST_F(CudnnConvolutionRewriterTest,
    124        BackwardFilterConvolveEquivalentToForwardConvolution) {
    125   HloComputation::Builder builder(TestName());
    126   HloInstruction* activations =
    127       builder.AddInstruction(HloInstruction::CreateParameter(
    128           0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations"));
    129   HloInstruction* gradients =
    130       builder.AddInstruction(HloInstruction::CreateParameter(
    131           1, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "gradients"));
    132   Window conv_window = default_conv_window_;
    133   conv_window.mutable_dimensions(1)->set_size(3);
    134   builder.AddInstruction(HloInstruction::CreateConvolve(
    135       ShapeInference::InferConvolveShape(activations->shape(),
    136                                          gradients->shape(), conv_window,
    137                                          tf_default_dnums_for_backward_filter_)
    138           .ConsumeValueOrDie(),
    139       activations, gradients, conv_window,
    140       tf_default_dnums_for_backward_filter_));
    141 
    142   auto module = CreateNewModule();
    143   HloComputation* entry_computation =
    144       module->AddEntryComputation(builder.Build());
    145   EXPECT_TRUE(RunPass(module.get()));
    146   EXPECT_THAT(entry_computation->root_instruction(),
    147               op::GetTupleElement(
    148                   op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
    149 }
    150 
    151 // Extracted from block35 training.
    152 TEST_F(CudnnConvolutionRewriterTest,
    153        BackwardFilterConvolveWithPaddedActivations) {
    154   auto builder = HloComputation::Builder(TestName());
    155   HloInstruction* activations =
    156       builder.AddInstruction(HloInstruction::CreateParameter(
    157           0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations"));
    158   HloInstruction* gradients =
    159       builder.AddInstruction(HloInstruction::CreateParameter(
    160           1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients"));
    161 
    162   Window conv_window = default_conv_window_;
    163   for (int i = 0; i < 2; ++i) {
    164     conv_window.mutable_dimensions(i)->set_size(35);
    165     conv_window.mutable_dimensions(i)->set_padding_low(1);
    166     conv_window.mutable_dimensions(i)->set_padding_high(1);
    167   }
    168   builder.AddInstruction(HloInstruction::CreateConvolve(
    169       ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients,
    170       conv_window, tf_default_dnums_for_backward_filter_));
    171 
    172   auto module = CreateNewModule();
    173   HloComputation* entry_computation =
    174       module->AddEntryComputation(builder.Build());
    175   EXPECT_TRUE(RunPass(module.get()));
    176   EXPECT_THAT(entry_computation->root_instruction(),
    177               op::GetTupleElement(
    178                   op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
    179 }
    180 
    181 // Extracted from inception v3 training.
    182 TEST_F(CudnnConvolutionRewriterTest,
    183        BackwardFilterConvolveWithPaddedGradients) {
    184   auto builder = HloComputation::Builder(TestName());
    185   HloInstruction* activations =
    186       builder.AddInstruction(HloInstruction::CreateParameter(
    187           0, ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), "activations"));
    188   HloInstruction* gradients =
    189       builder.AddInstruction(HloInstruction::CreateParameter(
    190           1, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "gradients"));
    191 
    192   Window conv_window = default_conv_window_;
    193   for (int i = 0; i < 2; ++i) {
    194     conv_window.mutable_dimensions(i)->set_size(4);
    195     conv_window.mutable_dimensions(i)->set_padding_high(-1);
    196     conv_window.mutable_dimensions(i)->set_window_dilation(2);
    197   }
    198   builder.AddInstruction(HloInstruction::CreateConvolve(
    199       ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients,
    200       conv_window, tf_default_dnums_for_backward_filter_));
    201 
    202   auto module = CreateNewModule();
    203   HloComputation* entry_computation =
    204       module->AddEntryComputation(builder.Build());
    205   EXPECT_TRUE(RunPass(module.get()));
    206   EXPECT_THAT(entry_computation->root_instruction(),
    207               op::GetTupleElement(
    208                   op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
    209 }
    210 
    211 TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) {
    212   auto builder = HloComputation::Builder(TestName());
    213   HloInstruction* activations =
    214       builder.AddInstruction(HloInstruction::CreateParameter(
    215           0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations"));
    216   HloInstruction* gradients =
    217       builder.AddInstruction(HloInstruction::CreateParameter(
    218           1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients"));
    219 
    220   Window conv_window = default_conv_window_;
    221   for (int i = 0; i < 2; ++i) {
    222     conv_window.mutable_dimensions(i)->set_size(35);
    223     // Uneven padding: padding_low=0, padding_high=1
    224     conv_window.mutable_dimensions(i)->set_padding_high(1);
    225   }
    226   builder.AddInstruction(HloInstruction::CreateConvolve(
    227       ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients,
    228       conv_window, tf_default_dnums_for_backward_filter_));
    229 
    230   auto module = CreateNewModule();
    231   HloComputation* entry_computation =
    232       module->AddEntryComputation(builder.Build());
    233   EXPECT_TRUE(RunPass(module.get()));
    234   EXPECT_THAT(entry_computation->root_instruction(),
    235               op::GetTupleElement(
    236                   op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
    237 }
    238 
    239 TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) {
    240   auto builder = HloComputation::Builder(TestName());
    241   HloInstruction* output =
    242       builder.AddInstruction(HloInstruction::CreateParameter(
    243           0, ShapeUtil::MakeShape(F32, {4, 5, 16, 16}), "output"));
    244   HloInstruction* kernel =
    245       builder.AddInstruction(HloInstruction::CreateParameter(
    246           1, ShapeUtil::MakeShape(F32, {5, 3, 7, 7}), "kernel"));
    247   HloInstruction* reverse_kernel = builder.AddInstruction(
    248       HloInstruction::CreateReverse(kernel->shape(), kernel, {2, 3}));
    249 
    250   Window conv_window = default_conv_window_;
    251   for (int i = 0; i < 2; ++i) {
    252     conv_window.mutable_dimensions(i)->set_size(7);
    253     conv_window.mutable_dimensions(i)->set_padding_low(3);
    254     conv_window.mutable_dimensions(i)->set_padding_high(3);
    255   }
    256   ConvolutionDimensionNumbers conv_dnums;
    257   conv_dnums.set_input_batch_dimension(0);
    258   conv_dnums.set_output_batch_dimension(0);
    259   conv_dnums.set_input_feature_dimension(1);
    260   conv_dnums.set_output_feature_dimension(1);
    261   conv_dnums.add_input_spatial_dimensions(2);
    262   conv_dnums.add_output_spatial_dimensions(2);
    263   conv_dnums.add_input_spatial_dimensions(3);
    264   conv_dnums.add_output_spatial_dimensions(3);
    265   conv_dnums.set_kernel_input_feature_dimension(0);
    266   conv_dnums.set_kernel_output_feature_dimension(1);
    267   conv_dnums.add_kernel_spatial_dimensions(2);
    268   conv_dnums.add_kernel_spatial_dimensions(3);
    269 
    270   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
    271       ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output,
    272       /*rhs=*/reverse_kernel, conv_window, conv_dnums));
    273   // Verify the convolution's shape is consistent with ShapeInference.
    274   CHECK(ShapeUtil::Compatible(
    275       conv->shape(),
    276       ShapeInference::InferConvolveShape(
    277           output->shape(), reverse_kernel->shape(), conv_window, conv_dnums)
    278           .ValueOrDie()));
    279 
    280   auto module = CreateNewModule();
    281   HloComputation* entry_computation =
    282       module->AddEntryComputation(builder.Build());
    283   EXPECT_TRUE(RunPass(module.get()));
    284 
    285   ASSERT_THAT(entry_computation->root_instruction(),
    286               op::GetTupleElement(
    287                   op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
    288   const HloInstruction* custom_call =
    289       entry_computation->root_instruction()->operand(0);
    290   for (int i = 0; i < 2; ++i) {
    291     const WindowDimension& window_dim = custom_call->window().dimensions(i);
    292     // Low padding of the backward input convolution
    293     //   = kernel_size - 1 - low padding on gradients.
    294     EXPECT_EQ(3, window_dim.padding_low());
    295     EXPECT_EQ(3, window_dim.padding_high());
    296     EXPECT_EQ(1, window_dim.stride());
    297   }
    298 }
    299 
    300 // Convolve([abc], [x], base_dilation=2)
    301 //   = Convolve([abc], Reverse([x]), base_dilation=2)
    302 //   = BackwardInputConvolve([abc], [x], stride=2)
    303 TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) {
    304   auto builder = HloComputation::Builder(TestName());
    305   // NHWC dimension order.
    306   HloInstruction* output =
    307       builder.AddInstruction(HloInstruction::CreateParameter(
    308           0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
    309   // HWOI dimension order.
    310   HloInstruction* kernel =
    311       builder.AddInstruction(HloInstruction::CreateParameter(
    312           1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));
    313 
    314   Window conv_window = default_conv_window_;
    315   conv_window.mutable_dimensions(1)->set_base_dilation(2);
    316 
    317   builder.AddInstruction(HloInstruction::CreateConvolve(
    318       ShapeInference::InferConvolveShape(output->shape(), kernel->shape(),
    319                                          conv_window,
    320                                          tf_default_dnums_for_backward_input_)
    321           .ConsumeValueOrDie(),
    322       /*lhs=*/output, /*rhs=*/kernel, conv_window,
    323       tf_default_dnums_for_backward_input_));
    324 
    325   auto module = CreateNewModule();
    326   HloComputation* entry_computation =
    327       module->AddEntryComputation(builder.Build());
    328   EXPECT_TRUE(RunPass(module.get()));
    329   EXPECT_THAT(entry_computation->root_instruction(),
    330               op::GetTupleElement(
    331                   op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
    332 }
    333 
    334 // BackwardInputConvolve([abc], [x], stride=1) is equivalent to
    335 // ForwardConvolve([abc], [x], stride=1). No need to fold it into backward input
    336 // convolution.
    337 TEST_F(CudnnConvolutionRewriterTest,
    338        BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) {
    339   auto builder = HloComputation::Builder(TestName());
    340   // NHWC dimension order.
    341   HloInstruction* output =
    342       builder.AddInstruction(HloInstruction::CreateParameter(
    343           0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
    344   // HWOI dimension order.
    345   HloInstruction* kernel =
    346       builder.AddInstruction(HloInstruction::CreateParameter(
    347           1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));
    348 
    349   builder.AddInstruction(HloInstruction::CreateConvolve(
    350       ShapeInference::InferConvolveShape(output->shape(), kernel->shape(),
    351                                          default_conv_window_,
    352                                          tf_default_dnums_for_backward_input_)
    353           .ConsumeValueOrDie(),
    354       /*lhs=*/output, /*rhs=*/kernel, default_conv_window_,
    355       tf_default_dnums_for_backward_input_));
    356 
    357   auto module = CreateNewModule();
    358   HloComputation* entry_computation =
    359       module->AddEntryComputation(builder.Build());
    360   EXPECT_TRUE(RunPass(module.get()));
    361   EXPECT_THAT(
    362       entry_computation->root_instruction(),
    363       op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
    364 }
    365 
    366 // Extracted from Inception V3 training.
    367 //
    368 //                                  filter(HWIO)
    369 //                                  3x3x192x320
    370 //                                      |
    371 //                                      v
    372 //      gradients(NHWC)              reverse
    373 //        20x4x4x320               3x3x192x320
    374 //                    \            /
    375 //                     \          /
    376 //  conv (NHWC) with padding (low=2,high=3,interior=1)
    377 //                     20x10x10x192
    378 //
    379 // Gradients are padded unevenly.
    380 TEST_F(CudnnConvolutionRewriterTest,
    381        BackwardInputConvolveUnevenPaddingOnGradients) {
    382   auto builder = HloComputation::Builder(TestName());
    383   HloInstruction* output =
    384       builder.AddInstruction(HloInstruction::CreateParameter(
    385           0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output"));
    386   HloInstruction* kernel =
    387       builder.AddInstruction(HloInstruction::CreateParameter(
    388           1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel"));
    389   HloInstruction* reverse_kernel = builder.AddInstruction(
    390       HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
    391 
    392   Window conv_window = default_conv_window_;
    393   for (int i = 0; i < 2; ++i) {
    394     conv_window.mutable_dimensions(i)->set_size(3);
    395     conv_window.mutable_dimensions(i)->set_padding_low(2);
    396     conv_window.mutable_dimensions(i)->set_padding_high(3);
    397     // Interior padding = 1.
    398     conv_window.mutable_dimensions(i)->set_base_dilation(2);
    399   }
    400   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
    401       ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
    402       conv_window, tf_default_dnums_for_backward_input_));
    403   // Verify the convolution's shape is consistent with ShapeInference.
    404   CHECK(ShapeUtil::Compatible(
    405       conv->shape(), ShapeInference::InferConvolveShape(
    406                          output->shape(), reverse_kernel->shape(), conv_window,
    407                          tf_default_dnums_for_backward_input_)
    408                          .ValueOrDie()));
    409 
    410   auto module = CreateNewModule();
    411   HloComputation* entry_computation =
    412       module->AddEntryComputation(builder.Build());
    413   EXPECT_TRUE(RunPass(module.get()));
    414   ASSERT_THAT(entry_computation->root_instruction(),
    415               op::GetTupleElement(
    416                   op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
    417   const HloInstruction* custom_call =
    418       entry_computation->root_instruction()->operand(0);
    419   for (int i = 0; i < 2; ++i) {
    420     const WindowDimension& window_dim = custom_call->window().dimensions(i);
    421     EXPECT_EQ(0, window_dim.padding_low());
    422     EXPECT_EQ(0, window_dim.padding_high());
    423     EXPECT_EQ(2, window_dim.stride());
    424   }
    425 }
    426 
    427 // Similar to BackwardInputConvolveUnevenPadding, but the low padding of the
    428 // gradients exceeds kernel_size - 1. Therefore, this pattern cannot be fused.
    429 TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
    430   auto builder = HloComputation::Builder(TestName());
    431   HloInstruction* output =
    432       builder.AddInstruction(HloInstruction::CreateParameter(
    433           0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output"));
    434   HloInstruction* kernel =
    435       builder.AddInstruction(HloInstruction::CreateParameter(
    436           1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel"));
    437   HloInstruction* reverse_kernel = builder.AddInstruction(
    438       HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
    439 
    440   Window conv_window = default_conv_window_;
    441   for (int i = 0; i < 2; ++i) {
    442     conv_window.mutable_dimensions(i)->set_size(3);
    443     conv_window.mutable_dimensions(i)->set_padding_low(3);
    444     conv_window.mutable_dimensions(i)->set_padding_high(2);
    445     conv_window.mutable_dimensions(i)->set_base_dilation(2);
    446   }
    447   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
    448       ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
    449       conv_window, tf_default_dnums_for_backward_input_));
    450   // Verify the convolution's shape is consistent with ShapeInference.
    451   CHECK(ShapeUtil::Compatible(
    452       conv->shape(), ShapeInference::InferConvolveShape(
    453                          output->shape(), reverse_kernel->shape(), conv_window,
    454                          tf_default_dnums_for_backward_input_)
    455                          .ValueOrDie()));
    456 
    457   auto module = CreateNewModule();
    458   HloComputation* entry_computation =
    459       module->AddEntryComputation(builder.Build());
    460   EXPECT_TRUE(RunPass(module.get()));
    461   EXPECT_THAT(
    462       entry_computation->root_instruction(),
    463       op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
    464 }
    465 
    466 // Extracted from //learning/brain/google/xla/benchmarks/resnet.py
    467 //
    468 // For simplicity, we focus on the column dimension and ignore other dimensions.
    469 // We use [?] to represent the shape instead of the content.
    470 //
    471 // Suppose operator FC does
    472 //   [4] = conv([14], [3], stride=2, padding_high=1)  // Padding::kSame
    473 //
    474 // BC = BackwardInput(FC) does:
    475 //   [14] = conv([7], reverse([3]),
    476 //               padding_low=2, padding_high=1, base_dilation=2)
    477 //
    478 // We should fuse BC even though padding on activations is uneven, because
    479 // PadInsertion will canonicalize the fusion HLO.
    480 TEST_F(CudnnConvolutionRewriterTest,
    481        BackwardInputConvolveUnevenPaddingOnActivations) {
    482   auto builder = HloComputation::Builder(TestName());
    483   // The gradients are in NCHW layout.
    484   HloInstruction* output =
    485       builder.AddInstruction(HloInstruction::CreateParameter(
    486           0, ShapeUtil::MakeShape(F32, {1, 1, 7, 1}), "output"));
    487   // The kernel is in HWIO layout.
    488   HloInstruction* kernel =
    489       builder.AddInstruction(HloInstruction::CreateParameter(
    490           1, ShapeUtil::MakeShape(F32, {1, 3, 1, 1}), "kernel"));
    491   HloInstruction* reverse_kernel = builder.AddInstruction(
    492       HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
    493 
    494   Window conv_window = default_conv_window_;
    495   WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1);
    496   forward_conv_col_dim->set_size(3);
    497   forward_conv_col_dim->set_padding_low(2);
    498   forward_conv_col_dim->set_padding_high(1);
    499   forward_conv_col_dim->set_base_dilation(2);
    500   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
    501       ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel,
    502       conv_window, tf_default_dnums_for_backward_input_));
    503   // Verify the convolution's shape is consistent with ShapeInference.
    504   CHECK(ShapeUtil::Compatible(
    505       conv->shape(), ShapeInference::InferConvolveShape(
    506                          output->shape(), reverse_kernel->shape(), conv_window,
    507                          tf_default_dnums_for_backward_input_)
    508                          .ValueOrDie()));
    509 
    510   auto module = CreateNewModule();
    511   const HloComputation* entry_computation =
    512       module->AddEntryComputation(builder.Build());
    513   EXPECT_TRUE(RunPass(module.get()));
    514   ASSERT_THAT(entry_computation->root_instruction(),
    515               op::GetTupleElement(
    516                   op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
    517   const WindowDimension& backward_conv_col_dim =
    518       entry_computation->root_instruction()->operand(0)->window().dimensions(1);
    519   EXPECT_EQ(0, backward_conv_col_dim.padding_low());
    520   EXPECT_EQ(1, backward_conv_col_dim.padding_high());
    521 }
    522 
    523 // For simplicity, we focus on the column dimension and ignore other dimensions.
    524 // We use [?] to represent the shape instead of the content.
    525 //
    526 // Suppose operator FC does
    527 //   [3] = conv([4], [2], padding_low=1, padding_high=-1)
    528 //
    529 // BC = BackwardInput(FC) does:
    530 //   [4] = conv([3], reverse([2]), padding_high=2)
    531 //
    532 // We currently don't fuse BC because PadInsertion doesn't support negative
    533 // padding on the gradients of backward convolution (b/32744257).
    534 TEST_F(CudnnConvolutionRewriterTest,
    535        BackwardInputConvolveNegativePaddingHighOnActivations) {
    536   auto builder = HloComputation::Builder(TestName());
    537   // The gradients are in NCHW layout.
    538   HloInstruction* output =
    539       builder.AddInstruction(HloInstruction::CreateParameter(
    540           0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
    541   // The kernel is in HWIO layout.
    542   HloInstruction* kernel =
    543       builder.AddInstruction(HloInstruction::CreateParameter(
    544           1, ShapeUtil::MakeShape(F32, {1, 2, 1, 1}), "kernel"));
    545   HloInstruction* reverse_kernel = builder.AddInstruction(
    546       HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
    547 
    548   Window conv_window = default_conv_window_;
    549   WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1);
    550   forward_conv_col_dim->set_size(2);
    551   forward_conv_col_dim->set_padding_high(2);
    552   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
    553       ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel,
    554       conv_window, tf_default_dnums_for_backward_input_));
    555   // Verify the convolution's shape is consistent with ShapeInference.
    556   CHECK(ShapeUtil::Compatible(
    557       conv->shape(), ShapeInference::InferConvolveShape(
    558                          output->shape(), reverse_kernel->shape(), conv_window,
    559                          tf_default_dnums_for_backward_input_)
    560                          .ValueOrDie()));
    561 
    562   auto module = CreateNewModule();
    563   HloComputation* entry_computation =
    564       module->AddEntryComputation(builder.Build());
    565   EXPECT_TRUE(RunPass(module.get()));
    566   EXPECT_THAT(
    567       entry_computation->root_instruction(),
    568       op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
    569 }
    570 
    571 }  // anonymous namespace
    572 }  // namespace gpu
    573 }  // namespace xla
    574