Home | History | Annotate | Download | only in lib
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/client/lib/pooling.h"
     17 #include "absl/container/inlined_vector.h"
     18 #include "tensorflow/compiler/xla/test.h"
     19 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
     20 #include "tensorflow/compiler/xla/tests/test_macros.h"
     21 
     22 namespace xla {
     23 namespace {
     24 
     25 TensorFormat MakeNCHWFormat(int num_spatial_dims) {
     26   absl::InlinedVector<int64, 4> spatial_dimensions;
     27   for (int i = 0; i < num_spatial_dims; ++i) {
     28     spatial_dimensions.push_back(i + 2);
     29   }
     30   return TensorFormat(/*batch_dimension=*/0, /*feature_dimension=*/1,
     31                       /*spatial_dimensions=*/spatial_dimensions);
     32 }
     33 
     34 std::vector<std::pair<int64, int64>> MakeGeneralPadding(
     35     XlaOp input, absl::Span<const int64> kernel_size,
     36     absl::Span<const int64> stride, Padding padding,
     37     const xla::TensorFormat& data_format) {
     38   XlaBuilder* b = input.builder();
     39   Shape operand_shape = b->GetShape(input).ValueOrDie();
     40   std::vector<int64> input_size(operand_shape.dimensions().begin(),
     41                                 operand_shape.dimensions().end());
     42   return MakeSpatialPadding(input_size, kernel_size, stride, padding,
     43                             data_format);
     44 }
     45 
     46 // Add singleton batch and feature dimensions to spatial dimensions, according
     47 // to 'data_format' specification.
     48 std::vector<int64> ExpandWithBatchAndFeatureDimensions(
     49     absl::Span<const int64> spatial_dim_sizes,
     50     const xla::TensorFormat& data_format) {
     51   const int num_spatial_dims = spatial_dim_sizes.size();
     52   std::vector<int64> tensor_sizes(num_spatial_dims + 2, 1);
     53   for (int i = 0; i < num_spatial_dims; ++i) {
     54     int dim = data_format.spatial_dimension(i);
     55     tensor_sizes[dim] = spatial_dim_sizes[i];
     56   }
     57   return tensor_sizes;
     58 }
     59 
     60 class PoolingTest : public ClientLibraryTestBase {
     61  public:
     62   ErrorSpec error_spec_{0.0001};
     63 };
     64 
     65 XLA_TEST_F(PoolingTest, MaxPool2D) {
     66   XlaBuilder builder(TestName());
     67 
     68   XlaOp input = ConstantR4FromArray4D<float>(
     69       &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
     70   auto data_format = MakeNCHWFormat(2);
     71   auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
     72   auto stride = kernel_size;
     73   MaxPool(input, kernel_size, stride, Padding::kValid, data_format);
     74 
     75   ComputeAndCompareR4<float>(&builder, {{{{5, 4}}}}, {}, error_spec_);
     76 }
     77 
     78 XLA_TEST_F(PoolingTest, MaxPool2DWithPadding) {
     79   XlaBuilder builder(TestName());
     80 
     81   XlaOp input = ConstantR4FromArray4D<float>(
     82       &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
     83   auto data_format = MakeNCHWFormat(2);
     84   auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
     85   auto stride = kernel_size;
     86   MaxPool(input, kernel_size, stride, Padding::kSame, data_format);
     87 
     88   ComputeAndCompareR4<float>(&builder, {{{{5, 4, 5}}}}, {}, error_spec_);
     89 }
     90 
     91 XLA_TEST_F(PoolingTest, MaxPool2DWithPaddingAndStride) {
     92   XlaBuilder builder(TestName());
     93 
     94   XlaOp input = ConstantR4FromArray4D<float>(
     95       &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
     96   auto data_format = MakeNCHWFormat(2);
     97   auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
     98   auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format);
     99   MaxPool(input, kernel_size, stride, Padding::kSame, data_format);
    100 
    101   ComputeAndCompareR4<float>(&builder, {{{{5, 4, 4, 5, 5}, {5, 4, 3, 2, 1}}}},
    102                              {}, error_spec_);
    103 }
    104 
    105 XLA_TEST_F(PoolingTest, AvgPool2D) {
    106   XlaBuilder builder(TestName());
    107 
    108   XlaOp input = ConstantR4FromArray4D<float>(
    109       &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
    110   auto data_format = MakeNCHWFormat(2);
    111   auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
    112   auto stride = kernel_size;
    113   auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kValid,
    114                                     data_format);
    115   AvgPool(input, kernel_size, stride, padding, data_format,
    116           /*counts_include_padding=*/true);
    117 
    118   ComputeAndCompareR4<float>(&builder, {{{{3, 3}}}}, {}, error_spec_);
    119 }
    120 
    121 XLA_TEST_F(PoolingTest, AvgPool2DWithPadding) {
    122   XlaBuilder builder(TestName());
    123 
    124   XlaOp input = ConstantR4FromArray4D<float>(
    125       &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
    126   auto data_format = MakeNCHWFormat(2);
    127   auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
    128   auto stride = kernel_size;
    129   auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kSame,
    130                                     data_format);
    131   AvgPool(input, kernel_size, stride, padding, data_format,
    132           /*counts_include_padding=*/false);
    133 
    134   ComputeAndCompareR4<float>(&builder, {{{{3, 3, 3}}}}, {}, error_spec_);
    135 }
    136 
    137 XLA_TEST_F(PoolingTest, AvgPool2DWithPaddingAndStride) {
    138   XlaBuilder builder(TestName());
    139 
    140   XlaOp input = ConstantR4FromArray4D<float>(
    141       &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
    142   auto data_format = MakeNCHWFormat(2);
    143   auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
    144   auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format);
    145   auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kSame,
    146                                     data_format);
    147   AvgPool(input, kernel_size, stride, padding, data_format,
    148           /*counts_include_padding=*/false);
    149 
    150   ComputeAndCompareR4<float>(&builder,
    151                              {{{{3, 3, 3, 3, 3}, {4.5, 3.5, 2.5, 1.5, 1}}}}, {},
    152                              error_spec_);
    153 }
    154 
    155 XLA_TEST_F(PoolingTest, AvgPool2DWithGeneralPaddingCountNotIncludePadding) {
    156   XlaBuilder builder(TestName());
    157 
    158   XlaOp input = ConstantR4FromArray4D<float>(
    159       &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
    160   auto data_format = MakeNCHWFormat(2);
    161   auto kernel_size = ExpandWithBatchAndFeatureDimensions({3, 3}, data_format);
    162   auto stride = kernel_size;
    163   AvgPool(input, kernel_size, stride, {{1, 1}, {2, 1}}, data_format,
    164           /*counts_include_padding=*/false);
    165 
    166   ComputeAndCompareR4<float>(&builder, {{{{3, 3}}}}, {}, error_spec_);
    167 }
    168 
    169 XLA_TEST_F(PoolingTest,
    170            AvgPool2DWithGeneralPaddingCountNotIncludePaddingAndStride) {
    171   XlaBuilder builder(TestName());
    172 
    173   XlaOp input = ConstantR4FromArray4D<float>(
    174       &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
    175   auto data_format = MakeNCHWFormat(2);
    176   auto kernel_size = ExpandWithBatchAndFeatureDimensions({3, 3}, data_format);
    177   auto stride = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
    178   AvgPool(input, kernel_size, stride, {{2, 1}, {1, 1}}, data_format,
    179           /*counts_include_padding=*/false);
    180 
    181   ComputeAndCompareR4<float>(&builder, {{{{1.5, 3, 4.5}, {3, 3, 3}}}}, {},
    182                              error_spec_);
    183 }
    184 
    185 XLA_TEST_F(PoolingTest, AvgPool2DGradNoPadding) {
    186   XlaBuilder builder(TestName());
    187   for (bool counts_include_padding : {false, true}) {
    188     XlaOp out_backprop = ConstantR4FromArray4D<float>(&builder, {{{{1.}}}});
    189     auto data_format = MakeNCHWFormat(2);
    190     auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
    191     auto stride = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
    192     AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride,
    193                 {{0, 0}, {0, 0}}, MakeNCHWFormat(2),
    194                 /*counts_include_padding=*/counts_include_padding);
    195     // Without padding, counts_include_padding makes no difference.
    196     ComputeAndCompareR4<float>(
    197         &builder, {{{{0.25, 0.25, 0.}, {0.25, 0.25, 0.}, {0., 0., 0.}}}}, {},
    198         error_spec_);
    199   }
    200 }
    201 
    202 XLA_TEST_F(PoolingTest, AvgPool2DGradNoPaddingWithStride) {
    203   XlaBuilder builder(TestName());
    204   for (bool counts_include_padding : {false, true}) {
    205     XlaOp out_backprop =
    206         ConstantR4FromArray4D<float>(&builder, {{{{1., 1.}, {1., 1.}}}});
    207     auto data_format = MakeNCHWFormat(2);
    208     auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
    209     auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format);
    210     AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride,
    211                 {{0, 0}, {0, 0}}, MakeNCHWFormat(2),
    212                 /*counts_include_padding=*/counts_include_padding);
    213     // Without padding, counts_include_padding makes no difference.
    214     ComputeAndCompareR4<float>(
    215         &builder, {{{{0.25, 0.5, 0.25}, {0.5, 1., 0.5}, {0.25, 0.5, 0.25}}}},
    216         {}, error_spec_);
    217   }
    218 }
    219 
    220 XLA_TEST_F(PoolingTest, AvgPool2DGradWithPadding) {
    221   XlaBuilder builder(TestName());
    222 
    223   XlaOp out_backprop =
    224       ConstantR4FromArray4D<float>(&builder, {{{{1., 1.}, {1., 1.}}}});
    225   auto data_format = MakeNCHWFormat(2);
    226   auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
    227   auto stride = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
    228   AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride, {{1, 1}, {1, 1}},
    229               MakeNCHWFormat(2),
    230               /*counts_include_padding=*/true);
    231   ComputeAndCompareR4<float>(
    232       &builder,
    233       {{{{0.25, 0.25, 0.25}, {0.25, 0.25, 0.25}, {0.25, 0.25, 0.25}}}}, {},
    234       error_spec_);
    235 }
    236 
    237 XLA_TEST_F(PoolingTest, AvgPool2DGradWithPaddingCountNotIncludePadding) {
    238   XlaBuilder builder(TestName());
    239 
    240   XlaOp out_backprop =
    241       ConstantR4FromArray4D<float>(&builder, {{{{1., 1.}, {1., 1.}}}});
    242   auto data_format = MakeNCHWFormat(2);
    243   auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
    244   auto stride = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
    245   AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride, {{1, 1}, {1, 1}},
    246               MakeNCHWFormat(2), false);
    247   ComputeAndCompareR4<float>(
    248       &builder, {{{{1., 0.5, 0.5}, {0.5, 0.25, 0.25}, {0.5, 0.25, 0.25}}}}, {},
    249       error_spec_);
    250 }
    251 
    252 XLA_TEST_F(PoolingTest, AvgPool2DGradWithPaddingCountWithStride) {
    253   XlaBuilder builder(TestName());
    254 
    255   XlaOp out_backprop =
    256       ConstantR4FromArray4D<float>(&builder, {{{{1., 1., 1., 1.},
    257                                                 {1., 1., 1., 1.},
    258                                                 {1., 1., 1., 1.},
    259                                                 {1., 1., 1., 1.}}}});
    260   auto data_format = MakeNCHWFormat(2);
    261   auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
    262   auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format);
    263   AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride, {{1, 1}, {1, 1}},
    264               MakeNCHWFormat(2), true);
    265   ComputeAndCompareR4<float>(&builder,
    266                              {{{{1., 1., 1.}, {1., 1., 1.}, {1., 1., 1.}}}}, {},
    267                              error_spec_);
    268 }
    269 
    270 XLA_TEST_F(PoolingTest,
    271            AvgPool2DGradWithPaddingCountWithStrideNotIncludePadding) {
    272   XlaBuilder builder(TestName());
    273 
    274   XlaOp out_backprop =
    275       ConstantR4FromArray4D<float>(&builder, {{{{1., 1., 1., 1.},
    276                                                 {1., 1., 1., 1.},
    277                                                 {1., 1., 1., 1.},
    278                                                 {1., 1., 1., 1.}}}});
    279   auto data_format = MakeNCHWFormat(2);
    280   auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
    281   auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format);
    282   AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride, {{1, 1}, {1, 1}},
    283               MakeNCHWFormat(2), false);
    284   ComputeAndCompareR4<float>(
    285       &builder, {{{{2.25, 1.5, 2.25}, {1.5, 1., 1.5}, {2.25, 1.5, 2.25}}}}, {},
    286       error_spec_);
    287 }
    288 
    289 }  // namespace
    290 }  // namespace xla
    291