Home | History | Annotate | Download | only in lib
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/client/lib/pooling.h"
     17 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
     18 #include "tensorflow/compiler/xla/client/lib/constants.h"
     19 #include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h"
     20 
     21 namespace xla {
     22 
     23 namespace {
     24 
     25 // Common computation shared between AvgPool and AvgPoolGrad. Divide each
     26 // element of an image by the count of elements that contributed to that
     27 // element during pooling.
     28 XlaOp AvgPoolDivideByCountWithGeneralPadding(
     29     XlaOp sums, PrimitiveType dtype, absl::Span<const int64> input_shape,
     30     absl::Span<const std::pair<int64, int64>> spatial_padding,
     31     absl::Span<const int64> ksize, absl::Span<const int64> stride,
     32     const TensorFormat& data_format) {
     33   // The padding shouldn't be included in the counts. We use another
     34   // ReduceWindow to find the right counts.
     35   const int num_spatial_dims = spatial_padding.size();
     36 
     37   std::vector<int64> input_dim_sizes(num_spatial_dims);
     38   std::vector<int64> window_dims(num_spatial_dims);
     39   std::vector<int64> window_ksize(num_spatial_dims);
     40   std::vector<int64> window_stride(num_spatial_dims);
     41   CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims)
     42       << "Invalid number of spatial dimentions in data format specification";
     43   for (int i = 0; i < num_spatial_dims; ++i) {
     44     int dim = data_format.spatial_dimension(i);
     45     input_dim_sizes[i] = input_shape[dim];
     46     window_dims[i] = dim;
     47     window_ksize[i] = ksize[dim];
     48     window_stride[i] = stride[dim];
     49   }
     50 
     51   XlaBuilder* b = sums.builder();
     52   // Build a matrix of all 1s, with the same width/height as the input.
     53   auto ones = Broadcast(One(b, dtype), input_dim_sizes);
     54   PaddingConfig padding_config;
     55   for (int i = 0; i < num_spatial_dims; ++i) {
     56     auto dims = padding_config.add_dimensions();
     57     dims->set_edge_padding_low(spatial_padding[i].first);
     58     dims->set_edge_padding_high(spatial_padding[i].second);
     59   }
     60   auto zero = Zero(b, dtype);
     61   auto padded_ones = Pad(ones, zero, padding_config);
     62 
     63   // Perform a ReduceWindow with the same window size, strides, and padding
     64   // to count the number of contributions to each result element.
     65   auto counts =
     66       ReduceWindow(padded_ones, zero, CreateScalarAddComputation(dtype, b),
     67                    window_ksize, window_stride, Padding::kValid);
     68 
     69   return Div(sums, counts, window_dims);
     70 }
     71 
     72 // Sums all elements in the window specified by 'kernel_size' and 'stride'.
     73 XlaOp ComputeSums(XlaOp operand, XlaOp init_value,
     74                   absl::Span<const int64> kernel_size,
     75                   absl::Span<const int64> stride,
     76                   const TensorFormat& data_format) {
     77   XlaBuilder* b = operand.builder();
     78   return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     79     TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand));
     80     TF_ASSIGN_OR_RETURN(Shape init_shape, b->GetShape(init_value));
     81     PrimitiveType accumulation_type = init_shape.element_type();
     82     auto add_computation = CreateScalarAddComputation(accumulation_type, b);
     83     return ReduceWindow(operand, init_value, add_computation, kernel_size,
     84                         stride, Padding::kValid);
     85   });
     86 }
     87 
     88 // Creates a padding configuration out of spatial padding values.
     89 PaddingConfig MakeSpatialPaddingConfig(
     90     absl::Span<const std::pair<int64, int64>> spatial_padding,
     91     int num_spatial_dims, absl::Span<const int64> stride,
     92     const TensorFormat& data_format) {
     93   PaddingConfig padding_config;
     94   for (int i = 0; i < 2 + num_spatial_dims; ++i) {
     95     padding_config.add_dimensions();
     96   }
     97   CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims)
     98       << "Invalid number of spatial dimentions in data format specification";
     99   for (int i = 0; i < num_spatial_dims; ++i) {
    100     int dim = data_format.spatial_dimension(i);
    101     auto padding_dimension = padding_config.mutable_dimensions(dim);
    102     padding_dimension->set_edge_padding_low(spatial_padding[i].first);
    103     padding_dimension->set_edge_padding_high(spatial_padding[i].second);
    104   }
    105   return padding_config;
    106 }
    107 
    108 XlaOp AvgPoolDivideByCount(XlaOp pooled, absl::Span<const int64> input_size,
    109                            absl::Span<const int64> window_dimensions,
    110                            absl::Span<const int64> window_strides,
    111                            absl::Span<const std::pair<int64, int64>> padding,
    112                            PrimitiveType dtype, const TensorFormat& data_format,
    113                            bool counts_include_padding) {
    114   if (counts_include_padding) {
    115     // If counts include padding, all windows have the same number of elements
    116     // contributing to each average. Divide by the window size everywhere to get
    117     // the average.
    118     int64 window_size =
    119         std::accumulate(window_dimensions.begin(), window_dimensions.end(), 1,
    120                         [](int64 a, int64 b) { return a * b; });
    121     auto divisor = ConstantR0WithType(pooled.builder(), dtype, window_size);
    122 
    123     return pooled / divisor;
    124   } else {
    125     return AvgPoolDivideByCountWithGeneralPadding(pooled, dtype, input_size,
    126                                                   padding, window_dimensions,
    127                                                   window_strides, data_format);
    128   }
    129 }
    130 
    131 }  // namespace
    132 
    133 XlaOp MaxPool(XlaOp operand, absl::Span<const int64> kernel_size,
    134               absl::Span<const int64> stride, Padding padding,
    135               const TensorFormat& data_format) {
    136   XlaBuilder* b = operand.builder();
    137   return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    138     TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand));
    139     PrimitiveType dtype = operand_shape.element_type();
    140     auto max_computation = CreateScalarMaxComputation(dtype, b);
    141     auto init_value = MinValue(b, dtype);
    142     return ReduceWindow(operand, init_value, max_computation, kernel_size,
    143                         stride, padding);
    144   });
    145 }
    146 
    147 XlaOp AvgPool(XlaOp operand, absl::Span<const int64> kernel_size,
    148               absl::Span<const int64> stride,
    149               absl::Span<const std::pair<int64, int64>> padding,
    150               const TensorFormat& data_format,
    151               const bool counts_include_padding) {
    152   XlaBuilder* b = operand.builder();
    153   return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    154     TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand));
    155     PrimitiveType dtype = operand_shape.element_type();
    156     auto init_value = Zero(b, dtype);
    157     std::vector<int64> input_size(operand_shape.dimensions().begin(),
    158                                   operand_shape.dimensions().end());
    159     const int num_dims = kernel_size.size();
    160     const int num_spatial_dims = num_dims - 2;
    161     auto padding_config = MakeSpatialPaddingConfig(padding, num_spatial_dims,
    162                                                    stride, data_format);
    163     auto padded_operand = Pad(operand, Zero(b, dtype), padding_config);
    164     auto pooled = ComputeSums(padded_operand, init_value, kernel_size, stride,
    165                               data_format);
    166     return AvgPoolDivideByCount(pooled, input_size, kernel_size, stride,
    167                                 padding, dtype, data_format,
    168                                 counts_include_padding);
    169   });
    170 }
    171 
    172 std::vector<std::pair<int64, int64>> MakeSpatialPadding(
    173     absl::Span<const int64> input_size, absl::Span<const int64> kernel_size,
    174     absl::Span<const int64> stride, Padding padding,
    175     const TensorFormat& data_format) {
    176   const int num_spatial_dims = kernel_size.size() - 2;
    177   std::vector<int64> input_spatial_dimensions;
    178   std::vector<int64> kernel_size_spatial_dimensions;
    179   std::vector<int64> stride_spatial_dimensions;
    180   CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims)
    181       << "Invalid number of spatial dimentions in data format specification";
    182   for (int i = 0; i < num_spatial_dims; ++i) {
    183     int dim = data_format.spatial_dimension(i);
    184     input_spatial_dimensions.push_back(input_size[dim]);
    185     kernel_size_spatial_dimensions.push_back(kernel_size[dim]);
    186     stride_spatial_dimensions.push_back(stride[dim]);
    187   }
    188   return MakePadding(input_spatial_dimensions, kernel_size_spatial_dimensions,
    189                      stride_spatial_dimensions, padding);
    190 }
    191 
    192 XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span<const int64> gradients_size,
    193                   absl::Span<const int64> kernel_size,
    194                   absl::Span<const int64> stride,
    195                   absl::Span<const std::pair<int64, int64>> spatial_padding,
    196                   const TensorFormat& data_format,
    197                   const bool counts_include_padding) {
    198   XlaBuilder* b = out_backprop.builder();
    199   return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    200     const int num_dims = kernel_size.size();
    201 
    202     if (gradients_size.size() != num_dims) {
    203       return tensorflow::errors::InvalidArgument("gradients must be ", num_dims,
    204                                                  "-dimensional");
    205     }
    206 
    207     TF_ASSIGN_OR_RETURN(Shape out_backprop_xla_shape,
    208                         b->GetShape(out_backprop));
    209     if (out_backprop_xla_shape.dimensions().size() != num_dims) {
    210       return tensorflow::errors::InvalidArgument("out_backprop must be ",
    211                                                  num_dims, "-dimensional");
    212     }
    213 
    214     // We can think of average-pooling as:
    215     // * a convolution with a kernel consisting entirely of 1s, where the
    216     //   input feature and output feature are equal, and 0s everywhere else.
    217     // * followed by dividing by the counts.
    218     //
    219     // This then gives us an algorithm to build the gradient:
    220     // * divide out_backprop by the counts, followed by
    221     // * Conv2DBackpropInput specialized for that kernel, which simplifies to
    222     //   a Pad and a ReduceWindow.
    223     //
    224     // For an explanation of backpropagation for convolution, see the comments
    225     // in third_party/tensorflow/core/kernels/conv_grad_ops.h
    226 
    227     // TF filter shape is [ H, W, ..., inC, outC ]
    228 
    229     // The input gradients are computed by a convolution of the output gradients
    230     // and the filter, with some appropriate padding. See the comment at the top
    231     // of conv_grad_ops.h for details.
    232     PrimitiveType dtype = out_backprop_xla_shape.element_type();
    233     auto out_backprop_div = AvgPoolDivideByCount(
    234         out_backprop, gradients_size, kernel_size, stride, spatial_padding,
    235         dtype, data_format, counts_include_padding);
    236 
    237     // Pad the gradients in the spatial dimensions. We use the same padding
    238     // as Conv2DBackpropInput.
    239     PaddingConfig padding_config = MakeNoPaddingConfig(num_dims);
    240     std::vector<int64> padded_gradients_size(gradients_size.begin(),
    241                                              gradients_size.end());
    242     // First, pad the output gradients the same way as the input. The additional
    243     // padding will be removed as a last step before returning the input
    244     // gradients.
    245     const int num_spatial_dims = num_dims - 2;
    246     for (int i = 0; i < num_spatial_dims; ++i) {
    247       int dim = data_format.spatial_dimension(i);
    248       padded_gradients_size[dim] +=
    249           (spatial_padding[i].first + spatial_padding[i].second);
    250     }
    251     for (int i = 0; i < num_spatial_dims; ++i) {
    252       int dim = data_format.spatial_dimension(i);
    253       TF_ASSIGN_OR_RETURN(
    254           SpatialDimensionOutputSizeAndPadding conv_backprop_spatial_dim,
    255           ConvGradExtractAndVerifyDimension(
    256               /*input_size=*/padded_gradients_size[dim],
    257               /*filter_size=*/kernel_size[dim],
    258               /*output_size=*/out_backprop_xla_shape.dimensions(dim),
    259               /*dilation=*/1,
    260               /*stride=*/stride[dim], /*padding=*/Padding::kValid));
    261       auto* padding = padding_config.mutable_dimensions(dim);
    262       padding->set_edge_padding_low(conv_backprop_spatial_dim.pad_before);
    263       padding->set_edge_padding_high(conv_backprop_spatial_dim.pad_after);
    264       padding->set_interior_padding(stride[dim] - 1);
    265     }
    266 
    267     auto zero = Zero(b, dtype);
    268     auto padded_gradients = Pad(out_backprop_div, zero, padding_config);
    269 
    270     // in_backprop = padded_gradients <conv> ones
    271     std::vector<int64> ones(num_dims, 1LL);
    272     auto in_backprop =
    273         ReduceWindow(padded_gradients, Zero(b, dtype),
    274                      CreateScalarAddComputation(dtype, b), kernel_size,
    275                      /*window_strides=*/ones, Padding::kValid);
    276     // The input padding doesn't contribute to the gradient, remove it.
    277     std::vector<std::pair<int64, int64>> neg_spatial_padding;
    278     neg_spatial_padding.reserve(spatial_padding.size());
    279     for (const std::pair<int64, int64>& spatial_padding_dim : spatial_padding) {
    280       neg_spatial_padding.emplace_back(-spatial_padding_dim.first,
    281                                        -spatial_padding_dim.second);
    282     }
    283     auto remove_padding_config = MakeSpatialPaddingConfig(
    284         neg_spatial_padding, num_spatial_dims, stride, data_format);
    285     return Pad(in_backprop, zero, remove_padding_config);
    286   });
    287 }
    288 
    289 }  // namespace xla
    290