1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/client/lib/pooling.h" 17 #include "tensorflow/compiler/xla/client/lib/arithmetic.h" 18 #include "tensorflow/compiler/xla/client/lib/constants.h" 19 #include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h" 20 21 namespace xla { 22 23 namespace { 24 25 // Common computation shared between AvgPool and AvgPoolGrad. Divide each 26 // element of an image by the count of elements that contributed to that 27 // element during pooling. 28 XlaOp AvgPoolDivideByCountWithGeneralPadding( 29 XlaOp sums, PrimitiveType dtype, absl::Span<const int64> input_shape, 30 absl::Span<const std::pair<int64, int64>> spatial_padding, 31 absl::Span<const int64> ksize, absl::Span<const int64> stride, 32 const TensorFormat& data_format) { 33 // The padding shouldn't be included in the counts. We use another 34 // ReduceWindow to find the right counts. 35 const int num_spatial_dims = spatial_padding.size(); 36 37 std::vector<int64> input_dim_sizes(num_spatial_dims); 38 std::vector<int64> window_dims(num_spatial_dims); 39 std::vector<int64> window_ksize(num_spatial_dims); 40 std::vector<int64> window_stride(num_spatial_dims); 41 CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims) 42 << "Invalid number of spatial dimentions in data format specification"; 43 for (int i = 0; i < num_spatial_dims; ++i) { 44 int dim = data_format.spatial_dimension(i); 45 input_dim_sizes[i] = input_shape[dim]; 46 window_dims[i] = dim; 47 window_ksize[i] = ksize[dim]; 48 window_stride[i] = stride[dim]; 49 } 50 51 XlaBuilder* b = sums.builder(); 52 // Build a matrix of all 1s, with the same width/height as the input. 53 auto ones = Broadcast(One(b, dtype), input_dim_sizes); 54 PaddingConfig padding_config; 55 for (int i = 0; i < num_spatial_dims; ++i) { 56 auto dims = padding_config.add_dimensions(); 57 dims->set_edge_padding_low(spatial_padding[i].first); 58 dims->set_edge_padding_high(spatial_padding[i].second); 59 } 60 auto zero = Zero(b, dtype); 61 auto padded_ones = Pad(ones, zero, padding_config); 62 63 // Perform a ReduceWindow with the same window size, strides, and padding 64 // to count the number of contributions to each result element. 65 auto counts = 66 ReduceWindow(padded_ones, zero, CreateScalarAddComputation(dtype, b), 67 window_ksize, window_stride, Padding::kValid); 68 69 return Div(sums, counts, window_dims); 70 } 71 72 // Sums all elements in the window specified by 'kernel_size' and 'stride'. 73 XlaOp ComputeSums(XlaOp operand, XlaOp init_value, 74 absl::Span<const int64> kernel_size, 75 absl::Span<const int64> stride, 76 const TensorFormat& data_format) { 77 XlaBuilder* b = operand.builder(); 78 return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { 79 TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); 80 TF_ASSIGN_OR_RETURN(Shape init_shape, b->GetShape(init_value)); 81 PrimitiveType accumulation_type = init_shape.element_type(); 82 auto add_computation = CreateScalarAddComputation(accumulation_type, b); 83 return ReduceWindow(operand, init_value, add_computation, kernel_size, 84 stride, Padding::kValid); 85 }); 86 } 87 88 // Creates a padding configuration out of spatial padding values. 89 PaddingConfig MakeSpatialPaddingConfig( 90 absl::Span<const std::pair<int64, int64>> spatial_padding, 91 int num_spatial_dims, absl::Span<const int64> stride, 92 const TensorFormat& data_format) { 93 PaddingConfig padding_config; 94 for (int i = 0; i < 2 + num_spatial_dims; ++i) { 95 padding_config.add_dimensions(); 96 } 97 CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims) 98 << "Invalid number of spatial dimentions in data format specification"; 99 for (int i = 0; i < num_spatial_dims; ++i) { 100 int dim = data_format.spatial_dimension(i); 101 auto padding_dimension = padding_config.mutable_dimensions(dim); 102 padding_dimension->set_edge_padding_low(spatial_padding[i].first); 103 padding_dimension->set_edge_padding_high(spatial_padding[i].second); 104 } 105 return padding_config; 106 } 107 108 XlaOp AvgPoolDivideByCount(XlaOp pooled, absl::Span<const int64> input_size, 109 absl::Span<const int64> window_dimensions, 110 absl::Span<const int64> window_strides, 111 absl::Span<const std::pair<int64, int64>> padding, 112 PrimitiveType dtype, const TensorFormat& data_format, 113 bool counts_include_padding) { 114 if (counts_include_padding) { 115 // If counts include padding, all windows have the same number of elements 116 // contributing to each average. Divide by the window size everywhere to get 117 // the average. 118 int64 window_size = 119 std::accumulate(window_dimensions.begin(), window_dimensions.end(), 1, 120 [](int64 a, int64 b) { return a * b; }); 121 auto divisor = ConstantR0WithType(pooled.builder(), dtype, window_size); 122 123 return pooled / divisor; 124 } else { 125 return AvgPoolDivideByCountWithGeneralPadding(pooled, dtype, input_size, 126 padding, window_dimensions, 127 window_strides, data_format); 128 } 129 } 130 131 } // namespace 132 133 XlaOp MaxPool(XlaOp operand, absl::Span<const int64> kernel_size, 134 absl::Span<const int64> stride, Padding padding, 135 const TensorFormat& data_format) { 136 XlaBuilder* b = operand.builder(); 137 return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { 138 TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); 139 PrimitiveType dtype = operand_shape.element_type(); 140 auto max_computation = CreateScalarMaxComputation(dtype, b); 141 auto init_value = MinValue(b, dtype); 142 return ReduceWindow(operand, init_value, max_computation, kernel_size, 143 stride, padding); 144 }); 145 } 146 147 XlaOp AvgPool(XlaOp operand, absl::Span<const int64> kernel_size, 148 absl::Span<const int64> stride, 149 absl::Span<const std::pair<int64, int64>> padding, 150 const TensorFormat& data_format, 151 const bool counts_include_padding) { 152 XlaBuilder* b = operand.builder(); 153 return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { 154 TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); 155 PrimitiveType dtype = operand_shape.element_type(); 156 auto init_value = Zero(b, dtype); 157 std::vector<int64> input_size(operand_shape.dimensions().begin(), 158 operand_shape.dimensions().end()); 159 const int num_dims = kernel_size.size(); 160 const int num_spatial_dims = num_dims - 2; 161 auto padding_config = MakeSpatialPaddingConfig(padding, num_spatial_dims, 162 stride, data_format); 163 auto padded_operand = Pad(operand, Zero(b, dtype), padding_config); 164 auto pooled = ComputeSums(padded_operand, init_value, kernel_size, stride, 165 data_format); 166 return AvgPoolDivideByCount(pooled, input_size, kernel_size, stride, 167 padding, dtype, data_format, 168 counts_include_padding); 169 }); 170 } 171 172 std::vector<std::pair<int64, int64>> MakeSpatialPadding( 173 absl::Span<const int64> input_size, absl::Span<const int64> kernel_size, 174 absl::Span<const int64> stride, Padding padding, 175 const TensorFormat& data_format) { 176 const int num_spatial_dims = kernel_size.size() - 2; 177 std::vector<int64> input_spatial_dimensions; 178 std::vector<int64> kernel_size_spatial_dimensions; 179 std::vector<int64> stride_spatial_dimensions; 180 CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims) 181 << "Invalid number of spatial dimentions in data format specification"; 182 for (int i = 0; i < num_spatial_dims; ++i) { 183 int dim = data_format.spatial_dimension(i); 184 input_spatial_dimensions.push_back(input_size[dim]); 185 kernel_size_spatial_dimensions.push_back(kernel_size[dim]); 186 stride_spatial_dimensions.push_back(stride[dim]); 187 } 188 return MakePadding(input_spatial_dimensions, kernel_size_spatial_dimensions, 189 stride_spatial_dimensions, padding); 190 } 191 192 XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span<const int64> gradients_size, 193 absl::Span<const int64> kernel_size, 194 absl::Span<const int64> stride, 195 absl::Span<const std::pair<int64, int64>> spatial_padding, 196 const TensorFormat& data_format, 197 const bool counts_include_padding) { 198 XlaBuilder* b = out_backprop.builder(); 199 return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { 200 const int num_dims = kernel_size.size(); 201 202 if (gradients_size.size() != num_dims) { 203 return tensorflow::errors::InvalidArgument("gradients must be ", num_dims, 204 "-dimensional"); 205 } 206 207 TF_ASSIGN_OR_RETURN(Shape out_backprop_xla_shape, 208 b->GetShape(out_backprop)); 209 if (out_backprop_xla_shape.dimensions().size() != num_dims) { 210 return tensorflow::errors::InvalidArgument("out_backprop must be ", 211 num_dims, "-dimensional"); 212 } 213 214 // We can think of average-pooling as: 215 // * a convolution with a kernel consisting entirely of 1s, where the 216 // input feature and output feature are equal, and 0s everywhere else. 217 // * followed by dividing by the counts. 218 // 219 // This then gives us an algorithm to build the gradient: 220 // * divide out_backprop by the counts, followed by 221 // * Conv2DBackpropInput specialized for that kernel, which simplifies to 222 // a Pad and a ReduceWindow. 223 // 224 // For an explanation of backpropagation for convolution, see the comments 225 // in third_party/tensorflow/core/kernels/conv_grad_ops.h 226 227 // TF filter shape is [ H, W, ..., inC, outC ] 228 229 // The input gradients are computed by a convolution of the output gradients 230 // and the filter, with some appropriate padding. See the comment at the top 231 // of conv_grad_ops.h for details. 232 PrimitiveType dtype = out_backprop_xla_shape.element_type(); 233 auto out_backprop_div = AvgPoolDivideByCount( 234 out_backprop, gradients_size, kernel_size, stride, spatial_padding, 235 dtype, data_format, counts_include_padding); 236 237 // Pad the gradients in the spatial dimensions. We use the same padding 238 // as Conv2DBackpropInput. 239 PaddingConfig padding_config = MakeNoPaddingConfig(num_dims); 240 std::vector<int64> padded_gradients_size(gradients_size.begin(), 241 gradients_size.end()); 242 // First, pad the output gradients the same way as the input. The additional 243 // padding will be removed as a last step before returning the input 244 // gradients. 245 const int num_spatial_dims = num_dims - 2; 246 for (int i = 0; i < num_spatial_dims; ++i) { 247 int dim = data_format.spatial_dimension(i); 248 padded_gradients_size[dim] += 249 (spatial_padding[i].first + spatial_padding[i].second); 250 } 251 for (int i = 0; i < num_spatial_dims; ++i) { 252 int dim = data_format.spatial_dimension(i); 253 TF_ASSIGN_OR_RETURN( 254 SpatialDimensionOutputSizeAndPadding conv_backprop_spatial_dim, 255 ConvGradExtractAndVerifyDimension( 256 /*input_size=*/padded_gradients_size[dim], 257 /*filter_size=*/kernel_size[dim], 258 /*output_size=*/out_backprop_xla_shape.dimensions(dim), 259 /*dilation=*/1, 260 /*stride=*/stride[dim], /*padding=*/Padding::kValid)); 261 auto* padding = padding_config.mutable_dimensions(dim); 262 padding->set_edge_padding_low(conv_backprop_spatial_dim.pad_before); 263 padding->set_edge_padding_high(conv_backprop_spatial_dim.pad_after); 264 padding->set_interior_padding(stride[dim] - 1); 265 } 266 267 auto zero = Zero(b, dtype); 268 auto padded_gradients = Pad(out_backprop_div, zero, padding_config); 269 270 // in_backprop = padded_gradients <conv> ones 271 std::vector<int64> ones(num_dims, 1LL); 272 auto in_backprop = 273 ReduceWindow(padded_gradients, Zero(b, dtype), 274 CreateScalarAddComputation(dtype, b), kernel_size, 275 /*window_strides=*/ones, Padding::kValid); 276 // The input padding doesn't contribute to the gradient, remove it. 277 std::vector<std::pair<int64, int64>> neg_spatial_padding; 278 neg_spatial_padding.reserve(spatial_padding.size()); 279 for (const std::pair<int64, int64>& spatial_padding_dim : spatial_padding) { 280 neg_spatial_padding.emplace_back(-spatial_padding_dim.first, 281 -spatial_padding_dim.second); 282 } 283 auto remove_padding_config = MakeSpatialPaddingConfig( 284 neg_spatial_padding, num_spatial_dims, stride, data_format); 285 return Pad(in_backprop, zero, remove_padding_config); 286 }); 287 } 288 289 } // namespace xla 290