/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific Ops for 2D convolution.

#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"

#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
namespace {

// Returns the expanded size of a filter used for depthwise convolution.
// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N].
xla::Shape ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape& shape) {
  int num_dims = shape.dimensions_size();
  CHECK_GE(num_dims, 2);  // Crash OK
  xla::Shape expanded_shape = shape;
  expanded_shape.set_dimensions(
      num_dims - 1,
      shape.dimensions(num_dims - 2) * shape.dimensions(num_dims - 1));
  return expanded_shape;
}

// Create a mask for depthwise convolution that will make a normal convolution
// produce the same results as a depthwise convolution.
// For a [2, 2, 3, 2] depthwise filter this returns a [2, 2, 3, 6] tensor
//   1 1 0 0 0 0   1 1 0 0 0 0
//   0 0 1 1 0 0   0 0 1 1 0 0
//   0 0 0 0 1 1   0 0 0 0 1 1
//
//   1 1 0 0 0 0   1 1 0 0 0 0
//   0 0 1 1 0 0   0 0 1 1 0 0
//   0 0 0 0 1 1   0 0 0 0 1 1
//
// The first step is to create an iota A with iota_dimension = 2
//   0 0 0 0 0 0   0 0 0 0 0 0
//   1 1 1 1 1 1   1 1 1 1 1 1
//   2 2 2 2 2 2   2 2 2 2 2 2
//
//   0 0 0 0 0 0   0 0 0 0 0 0
//   1 1 1 1 1 1   1 1 1 1 1 1
//   2 2 2 2 2 2   2 2 2 2 2 2
//
// and another iota B with iota_dimension = 3
//   0 1 2 3 4 5   0 1 2 3 4 5
//   0 1 2 3 4 5   0 1 2 3 4 5
//   0 1 2 3 4 5   0 1 2 3 4 5
//
//   0 1 2 3 4 5   0 1 2 3 4 5
//   0 1 2 3 4 5   0 1 2 3 4 5
//   0 1 2 3 4 5   0 1 2 3 4 5
//
// and divide B by 2 to get
//   0 0 1 1 2 2   0 0 1 1 2 2
//   0 0 1 1 2 2   0 0 1 1 2 2
//   0 0 1 1 2 2   0 0 1 1 2 2
//
//   0 0 1 1 2 2   0 0 1 1 2 2
//   0 0 1 1 2 2   0 0 1 1 2 2
//   0 0 1 1 2 2   0 0 1 1 2 2
//
// Finally compare A and B and return the result at the beginning of the
// comment.
xla::XlaOp CreateExpandedFilterMask(const xla::Shape& filter_shape,
                                    xla::XlaBuilder* builder) {
  xla::Shape expanded_filter_shape =
      ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
  int64 depthwise_multiplier =
      filter_shape.dimensions(filter_shape.dimensions_size() - 1);

  // Create two iotas with the shape of the expanded filter, one of them with
  // the iota dimension chosen as the feature dimension, and the other an iota
  // with the iota dimension chosen as the expanded output feature dimension.
  std::vector<int64> iota_dimensions(expanded_filter_shape.dimensions().begin(),
                                     expanded_filter_shape.dimensions().end());
  xla::Shape iota_shape = xla::ShapeUtil::MakeShape(xla::S32, iota_dimensions);
  xla::XlaOp input_feature_iota = xla::Iota(
      builder, iota_shape, /*iota_dimension=*/iota_dimensions.size() - 2);
  xla::XlaOp expanded_feature_iota = xla::Iota(
      builder, iota_shape, /*iota_dimension=*/iota_dimensions.size() - 1);

  // Divide 'expanded_feature_iota' by the depthwise_multiplier to create
  // [0 0 1 1 2 2] ... in the example in the function comment.
  expanded_feature_iota =
      xla::Div(expanded_feature_iota,
               XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
                                          depthwise_multiplier));

  // Compare 'input_feature_iota' with 'expanded_feature_iota' to create a
  // diagonal predicate.
  return xla::Eq(expanded_feature_iota, input_feature_iota);
}

// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to
// build a depthwise convolution.
xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape,
                                                const xla::XlaOp& filter) {
  int64 input_feature_dim = filter_shape.dimensions_size() - 2;
  int64 output_feature_dim = filter_shape.dimensions_size() - 1;
  int64 depthwise_multiplier = filter_shape.dimensions(output_feature_dim);
  int64 input_feature = filter_shape.dimensions(input_feature_dim);

  // Create a [H, W, ..., 1, N*M] reshape of the filter.
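  // For example, the [2, 2, 3, 2] depthwise filter used in the mask example
  // above (M = 3, N = 2) is reshaped to [2, 2, 1, 6].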
  xla::Shape implicit_broadcast_filter_shape = filter_shape;
  implicit_broadcast_filter_shape.set_dimensions(input_feature_dim, 1);
  implicit_broadcast_filter_shape.set_dimensions(
      output_feature_dim, depthwise_multiplier * input_feature);
  return xla::Reshape(
      filter, xla::AsInt64Slice(implicit_broadcast_filter_shape.dimensions()));
}

// Reduces the results of the convolution with an expanded filter to the
// non-expanded filter.
xla::XlaOp ContractFilterForDepthwiseBackprop(const xla::Shape& filter_shape,
                                              const xla::XlaOp& filter_backprop,
                                              xla::XlaBuilder* builder) {
  auto masked_expanded_filter =
      xla::Select(CreateExpandedFilterMask(filter_shape, builder),
                  filter_backprop, xla::ZerosLike(filter_backprop));

  auto elem_type = filter_shape.element_type();
  return xla::Reshape(
      // This reduce does not need inputs to be converted with
      // XlaHelpers::SumAccumulationType() since the select above guarantees
      // that only one element is non-zero, so there cannot be accumulated
      // precision error.
      xla::Reduce(masked_expanded_filter, xla::Zero(builder, elem_type),
                  CreateScalarAddComputation(elem_type, builder),
                  {filter_shape.dimensions_size() - 2}),
      xla::AsInt64Slice(filter_shape.dimensions()));
}

// Performs some basic checks on ConvOpAttrs that are true for all kinds of XLA
// convolutions (as currently implemented).
Status CheckConvAttrs(const ConvOpAttrs& attrs) {
  const int num_dims = attrs.num_spatial_dims + 2;
  if (attrs.strides.size() != num_dims) {
    return errors::InvalidArgument(
        "Sliding window strides field must specify ", num_dims, " dimensions");
  }
  int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
  int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
  if (attrs.strides[batch_dim] != 1 || attrs.strides[feature_dim] != 1) {
    return errors::Unimplemented(
        "Current implementation does not yet support strides in the batch and "
        "depth dimensions.");
  }
  if (attrs.dilations.size() != num_dims) {
    return errors::InvalidArgument("Dilations field must specify ", num_dims,
                                   " dimensions");
  }
  if (attrs.dilations[batch_dim] != 1 || attrs.dilations[feature_dim] != 1) {
    return errors::Unimplemented(
        "Current implementation does not support dilations in the batch and "
        "depth dimensions.");
  }
  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
    int input_dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
    if (attrs.dilations[input_dim] < 1) {
      return errors::Unimplemented("Dilation values must be positive; ", i,
                                   "th spatial dimension had dilation ",
                                   attrs.dilations[input_dim]);
    }
  }
  return Status::OK();
}

// Wrapper around ConvBackpropComputeDimensions that converts from XLA shapes
// to TensorShapes.
Status ConvBackpropComputeDimensionsV2XlaShapes(
    StringPiece label, int num_spatial_dims, const xla::Shape& input_shape,
    const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape,
    absl::Span<const int32> dilations, const std::vector<int32>& strides,
    Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims,
    absl::Span<const int64> explicit_paddings) {
  TensorShape input_tensor_shape, filter_tensor_shape,
      out_backprop_tensor_shape;
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape));
  TF_RETURN_IF_ERROR(
      XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape));
  return ConvBackpropComputeDimensionsV2(
      label, num_spatial_dims, input_tensor_shape, filter_tensor_shape,
      out_backprop_tensor_shape, dilations, strides, padding, explicit_paddings,
      data_format, dims);
}

}  // anonymous namespace

xla::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(int num_spatial_dims,
                                               bool depthwise,
                                               OpKernelConstruction* ctx) {
  ConvOpAttrs attrs;
  attrs.num_spatial_dims = num_spatial_dims;
  attrs.depthwise = depthwise;
  TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations));
  TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides));
  TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding));
  if (attrs.padding == EXPLICIT) {
    TF_RETURN_IF_ERROR(
        ctx->GetAttr("explicit_paddings", &attrs.explicit_paddings));
  }

  string data_format;
  TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format));
  if (!FormatFromString(data_format, &attrs.data_format)) {
    return errors::InvalidArgument("Invalid data format: ", data_format);
  }

  return attrs;
}

xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
                                               xla::XlaOp conv_input,
                                               xla::XlaOp filter,
                                               const ConvOpAttrs& attrs) {
  TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));

  auto* builder = conv_input.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(conv_input));
  // Filter has the form [filter_rows, filter_cols, ..., in_depth, out_depth]
  TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));

  // For 2D convolution, there should be 4 dimensions.
  int num_dims = attrs.num_spatial_dims + 2;
  if (input_shape.dimensions_size() != num_dims) {
    return errors::InvalidArgument("input must be ", num_dims,
                                   "-dimensional: ", input_shape.DebugString());
  }
  if (filter_shape.dimensions_size() != num_dims) {
    return errors::InvalidArgument(
        "filter must be ", num_dims,
        "-dimensional: ", filter_shape.DebugString());
  }

  // The last two dimensions of the filter are the input and output shapes.
  int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
  int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);

  int64 in_depth = filter_shape.dimensions(attrs.num_spatial_dims);
  // The 'C' dimension for input is in_depth. It must be the same as
  // the filter's in_depth.
  if (in_depth != input_shape.dimensions(feature_dim)) {
    return errors::InvalidArgument(
        "input and filter must have the same depth: ", in_depth, " vs ",
        input_shape.dimensions(feature_dim));
  }

  if (attrs.depthwise) {
    filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter);
  }

  xla::ConvolutionDimensionNumbers dims;
  std::vector<int64> window_strides(attrs.num_spatial_dims);
  std::vector<int64> lhs_dilation(attrs.num_spatial_dims, 1);
  std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
  std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);

  dims.set_input_batch_dimension(batch_dim);
  dims.set_output_batch_dimension(batch_dim);
  dims.set_input_feature_dimension(feature_dim);
  dims.set_output_feature_dimension(feature_dim);
  dims.set_kernel_input_feature_dimension(attrs.num_spatial_dims);
  dims.set_kernel_output_feature_dimension(attrs.num_spatial_dims + 1);

  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
    const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
    dims.add_input_spatial_dimensions(dim);
    dims.add_kernel_spatial_dimensions(i);
    dims.add_output_spatial_dimensions(dim);
    window_strides[i] = attrs.strides.at(dim);
    rhs_dilation[i] = attrs.dilations.at(dim);

    if (attrs.padding == EXPLICIT) {
      padding[i] = {attrs.explicit_paddings.at(dim * 2),
                    attrs.explicit_paddings.at(dim * 2 + 1)};
    }

    int64 unused_output_size;
    TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
        input_shape.dimensions(dim), filter_shape.dimensions(i),
        rhs_dilation[i], window_strides[i], attrs.padding, &unused_output_size,
        &padding[i].first, &padding[i].second));
  }

  return xla::ConvGeneralDilated(
      conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
      dims, /*feature_group_count=*/attrs.depthwise ? in_depth : 1);
}

xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
    StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
    xla::XlaOp out_backprop, const ConvOpAttrs& attrs) {
  TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));

  int num_dims = attrs.num_spatial_dims + 2;
  int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
  int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);

  auto* builder = filter.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
  TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
                      builder->GetShape(out_backprop));

  xla::Shape expanded_filter_shape =
      attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
                      : filter_shape;
  // Reuse dimension computation logic from conv_grad_ops.cc.
  ConvBackpropDimensions dims;
  TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
      type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape,
      out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding,
      attrs.data_format, &dims, attrs.explicit_paddings));

  // The input gradients are computed by a convolution of the output
  // gradients and the filter, with some appropriate padding. See the
  // comment at the top of conv_grad_ops.h for details.
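  // For example, with a forward stride of 2 the output gradients are spaced
  // back out with zeros via lhs_dilation below (e.g. [a b c] becomes
  // [a 0 b 0 c]) and then convolved with the spatially mirrored filter.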

  xla::ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(batch_dim);
  dnums.set_output_batch_dimension(batch_dim);
  dnums.set_input_feature_dimension(feature_dim);
  dnums.set_output_feature_dimension(feature_dim);

  // TF filter shape is [ H, W, ..., inC, outC ]
  // Transpose the input and output features for computing the gradient.
  dnums.set_kernel_input_feature_dimension(attrs.num_spatial_dims + 1);
  dnums.set_kernel_output_feature_dimension(attrs.num_spatial_dims);

  std::vector<int64> kernel_spatial_dims(attrs.num_spatial_dims);
  std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
  std::vector<int64> lhs_dilation(attrs.num_spatial_dims);
  std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
  std::vector<int64> ones(attrs.num_spatial_dims, 1);
  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
    int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
    dnums.add_input_spatial_dimensions(dim);
    dnums.add_kernel_spatial_dimensions(i);
    dnums.add_output_spatial_dimensions(dim);

    kernel_spatial_dims[i] = i;
    padding[i] = {dims.spatial_dims[i].pad_before,
                  dims.spatial_dims[i].pad_after};
    lhs_dilation[i] = dims.spatial_dims[i].stride;
    rhs_dilation[i] = attrs.dilations[dim];
  }

  // Mirror the filter in the spatial dimensions.
  xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims);

  // activation gradients
  //   = gradients (with padding and dilation) <conv> mirrored_weights
  return xla::ConvGeneralDilated(
      out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
      lhs_dilation, rhs_dilation, dnums,
      /*feature_group_count=*/
      attrs.depthwise ? out_backprop_shape.dimensions(feature_dim) /
                            filter_shape.dimensions(attrs.num_spatial_dims + 1)
                      : 1);
}

xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
    StringPiece type_string, xla::XlaOp activations,
    const xla::Shape& filter_shape, xla::XlaOp gradients,
    const ConvOpAttrs& attrs) {
  TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));

  auto* builder = activations.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape activations_shape,
                      builder->GetShape(activations));
  TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
                      builder->GetShape(gradients));
  xla::XlaOp filter_backprop;

  xla::Shape input_shape = activations_shape;
  xla::Shape output_shape = out_backprop_shape;

  TensorShape input_tensor_shape, filter_tensor_shape, output_tensor_shape;
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape));
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(output_shape, &output_tensor_shape));

  const xla::Shape expanded_filter_shape =
      attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
                      : filter_shape;
  // Reuse dimension computation logic from conv_grad_ops.cc.
  ConvBackpropDimensions dims;
  // The filter gradients are computed by a convolution of the input
  // activations and the output gradients, with some appropriate padding.
  // See the comment at the top of conv_grad_ops.h for details.
  xla::ConvolutionDimensionNumbers dnums;

  TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
      type_string, attrs.num_spatial_dims, activations_shape,
      expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides,
      attrs.padding, attrs.data_format, &dims, attrs.explicit_paddings));

  // The activations (inputs) form the LHS of the convolution.
  // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
  // For the gradient computation, we flip the roles of the batch and
  // feature dimensions.
  // Each spatial entry has size in_depth * batch.

  // The last two dimensions of the filter are the input and output shapes.
  int num_dims = attrs.num_spatial_dims + 2;
  int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
  int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);

  bool use_batch_group_count =
      filter_tensor_shape.dim_size(num_dims - 1) == 1 && attrs.depthwise;

  std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
  std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
  std::vector<int64> window_strides(attrs.num_spatial_dims);
  std::vector<int64> ones(attrs.num_spatial_dims, 1);

  // Swap n_dim and c_dim in the activations.
  dnums.set_input_batch_dimension(c_dim);
  dnums.set_input_feature_dimension(n_dim);

  // The gradients become the RHS of the convolution.
  // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
  // where the batch becomes the input feature for the convolution.
  dnums.set_kernel_input_feature_dimension(n_dim);
  dnums.set_kernel_output_feature_dimension(c_dim);

  // The dimension swap below is needed because the filter shape is KH,KW,F,DM.
  if (use_batch_group_count) {
    dnums.set_output_batch_dimension(attrs.num_spatial_dims + 1);
    dnums.set_output_feature_dimension(attrs.num_spatial_dims);
  } else {
    dnums.set_output_batch_dimension(attrs.num_spatial_dims);
    dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1);
  }

  // TensorFlow filter shape is [ H, W, ..., inC, outC ].
  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
    dnums.add_output_spatial_dimensions(i);
  }

  for (int64 i = 0; i < attrs.num_spatial_dims; ++i) {
    int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
    dnums.add_input_spatial_dimensions(dim);
    dnums.add_kernel_spatial_dimensions(dim);
    rhs_dilation[i] = dims.spatial_dims[i].stride;
    window_strides[i] = attrs.dilations[dim];

    // We will also need to pad the input with zeros such that after the
    // convolution, we get the right size for the filter.
    // The padded_in_rows should be such that when we convolve this with the
    // expanded_out_rows as a filter, we should get filter_rows back.

    const int64 padded_in_size =
        dims.spatial_dims[i].expanded_output_size +
        (dims.spatial_dims[i].filter_size - 1) * attrs.dilations[dim];

    // However it can be smaller than input_rows: in this
    // case it means some of the inputs are not used.
    //
    // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
    //
    //   INPUT  = [ A  B  C ]
    //
    //   FILTER = [ x  y ]
    //
    // and the output will only have one column: a = A * x + B * y
    //
    // and input "C" is not used at all.
    //
    // We apply negative padding in this case.
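    //
    // Continuing that example (assuming VALID padding and dilation 1 for
    // illustration): expanded_output_size = (1 - 1) * 2 + 1 = 1, so
    // padded_in_size = 1 + (2 - 1) * 1 = 2, one less than input_cols = 3;
    // the resulting padding of -1 on the right drops input "C".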
    const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size;

    // + For the EXPLICIT padding, we pad the top/left side with the explicit
    //   padding and pad the bottom/right side with the remaining space.
    // + For the VALID padding, we don't pad anything on the top/left side
    //   and pad the bottom/right side with the remaining space.
    // + For the SAME padding, we pad the top/left side with half of pad_total
    //   (rounded down) and the bottom/right side with the remaining space.
    //
    // In addition, if the padded input size is smaller than the input size,
    // we need to ignore some trailing elements of the input. We do this by
    // applying negative padding on the right/bottom.
    const int64 pad_before = attrs.padding == Padding::EXPLICIT
                                 ? attrs.explicit_paddings[2 * dim]
                                 : attrs.padding == Padding::SAME
                                       ? std::max<int64>(pad_total / 2, 0)
                                       : 0;
    padding[i] = {pad_before, pad_total - pad_before};
  }

  // Besides padding the input, we will also expand output_rows to
  //    expanded_out_rows = (output_rows - 1) * stride + 1
  // with zeros in between:
  //
  //      a . . . b . . . c . . . d . . . e
  //
  // This is done by specifying the window dilation factors in the
  // convolution HLO below.

  filter_backprop = xla::ConvGeneralDilated(
      activations, gradients, window_strides, padding, /*lhs_dilation=*/ones,
      rhs_dilation, dnums,
      /*feature_group_count=*/1,
      /*batch_group_count=*/use_batch_group_count ? dims.in_depth : 1);

  if (!use_batch_group_count && attrs.depthwise) {
    filter_backprop = ContractFilterForDepthwiseBackprop(
        filter_shape, filter_backprop, activations.builder());
  }

  return filter_backprop;
}

}  // namespace tensorflow