// Source: tensorflow/compiler/tf2xla/kernels/ (likely conv_op_helpers.cc, per
// the header include below); code-browser navigation line removed.
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // XLA-specific Ops for 2D convolution.
     17 
     18 #include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
     19 #include "absl/types/span.h"
     20 #include "tensorflow/compiler/tf2xla/shape_util.h"
     21 #include "tensorflow/compiler/tf2xla/type_util.h"
     22 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
     23 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
     24 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     25 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
     26 #include "tensorflow/compiler/xla/client/lib/constants.h"
     27 #include "tensorflow/compiler/xla/client/xla_builder.h"
     28 #include "tensorflow/compiler/xla/literal_util.h"
     29 #include "tensorflow/core/framework/bounds_check.h"
     30 #include "tensorflow/core/framework/node_def_util.h"
     31 #include "tensorflow/core/framework/numeric_op.h"
     32 #include "tensorflow/core/framework/op_kernel.h"
     33 #include "tensorflow/core/framework/ops_util.h"
     34 #include "tensorflow/core/framework/tensor.h"
     35 #include "tensorflow/core/framework/tensor_shape.h"
     36 #include "tensorflow/core/framework/tensor_slice.h"
     37 #include "tensorflow/core/kernels/conv_grad_ops.h"
     38 #include "tensorflow/core/util/padding.h"
     39 #include "tensorflow/core/util/tensor_format.h"
     40 
     41 namespace tensorflow {
     42 namespace {
     43 
     44 // Returns the expanded size of a filter used for depthwise convolution.
     45 // If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N].
     46 xla::Shape ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape& shape) {
     47   int num_dims = shape.dimensions_size();
     48   CHECK_GE(num_dims, 2);  // Crash OK
     49   xla::Shape expanded_shape = shape;
     50   expanded_shape.set_dimensions(
     51       num_dims - 1,
     52       shape.dimensions(num_dims - 2) * shape.dimensions(num_dims - 1));
     53   return expanded_shape;
     54 }
     55 
     56 // Create a mask for depthwise convolution that will make a normal convolution
     57 // produce the same results as a depthwise convolution. For a [2, 2, 3, 2]
     58 // depthwise filter this returns a [2, 2, 3, 6] tensor
     59 //   1 1 0 0 0 0   1 1 0 0 0 0
     60 //   0 0 1 1 0 0   0 0 1 1 0 0
     61 //   0 0 0 0 1 1   0 0 0 0 1 1
     62 //
     63 //   1 1 0 0 0 0   1 1 0 0 0 0
     64 //   0 0 1 1 0 0   0 0 1 1 0 0
     65 //   0 0 0 0 1 1   0 0 0 0 1 1
     66 //
     67 // The first step is to create a iota A with iota_dimension = 2
     68 //   0 0 0 0 0 0   0 0 0 0 0 0
     69 //   1 1 1 1 1 1   1 1 1 1 1 1
     70 //   2 2 2 2 2 2   2 2 2 2 2 2
     71 //
     72 //   0 0 0 0 0 0   0 0 0 0 0 0
     73 //   1 1 1 1 1 1   1 1 1 1 1 1
     74 //   2 2 2 2 2 2   2 2 2 2 2 2
     75 //
     76 // and another iota B with iota_dimension = 3
     77 //   0 1 2 3 4 5  0 1 2 3 4 5
     78 //   0 1 2 3 4 5  0 1 2 3 4 5
     79 //   0 1 2 3 4 5  0 1 2 3 4 5
     80 //
     81 //   0 1 2 3 4 5  0 1 2 3 4 5
     82 //   0 1 2 3 4 5  0 1 2 3 4 5
     83 //   0 1 2 3 4 5  0 1 2 3 4 5
     84 //
     85 // and divide B by 2 to get
     86 //   0 0 1 1 2 2  0 0 1 1 2 2
     87 //   0 0 1 1 2 2  0 0 1 1 2 2
     88 //   0 0 1 1 2 2  0 0 1 1 2 2
     89 //
     90 //   0 0 1 1 2 2  0 0 1 1 2 2
     91 //   0 0 1 1 2 2  0 0 1 1 2 2
     92 //   0 0 1 1 2 2  0 0 1 1 2 2
     93 //
     94 // Finally compare A and B and return the result at the beginning of the
     95 // comment.
     96 xla::XlaOp CreateExpandedFilterMask(const xla::Shape& filter_shape,
     97                                     xla::XlaBuilder* builder) {
     98   xla::Shape expanded_filter_shape =
     99       ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
    100   int64 depthwise_multiplier =
    101       filter_shape.dimensions(filter_shape.dimensions_size() - 1);
    102 
    103   // Create two iotas with the shape of the expanded filter, one of them with
    104   // the iota dimension chosen as the feature dimension, and the other a iota
    105   // with the iota dimension chosen as the expanded output feature dimension.
    106   std::vector<int64> iota_dimensions(expanded_filter_shape.dimensions().begin(),
    107                                      expanded_filter_shape.dimensions().end());
    108   xla::Shape iota_shape = xla::ShapeUtil::MakeShape(xla::S32, iota_dimensions);
    109   xla::XlaOp input_feature_iota = xla::Iota(
    110       builder, iota_shape, /*iota_dimension=*/iota_dimensions.size() - 2);
    111   xla::XlaOp expanded_feature_iota = xla::Iota(
    112       builder, iota_shape, /*iota_dimension=*/iota_dimensions.size() - 1);
    113 
    114   // Divide 'expanded_feature_iota' by the depthwise_multiplier to create
    115   // [0 0 1 1 2 2] ... in the example in the function comment.
    116   expanded_feature_iota =
    117       xla::Div(expanded_feature_iota,
    118                XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
    119                                           depthwise_multiplier));
    120 
    121   // Compare 'input_feature_iota' with 'expanded_feature_iota' to create a
    122   // diagonal predicate.
    123   return xla::Eq(expanded_feature_iota, input_feature_iota);
    124 }
    125 
    126 // Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to
    127 // build a depthwise convolution.
    128 xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape,
    129                                                 const xla::XlaOp& filter) {
    130   int64 input_feature_dim = filter_shape.dimensions_size() - 2;
    131   int64 output_feature_dim = filter_shape.dimensions_size() - 1;
    132   int64 depthwise_multiplier = filter_shape.dimensions(output_feature_dim);
    133   int64 input_feature = filter_shape.dimensions(input_feature_dim);
    134 
    135   // Create a [H, W, ..., 1, N*M] reshape of the filter.
    136   xla::Shape implicit_broadcast_filter_shape = filter_shape;
    137   implicit_broadcast_filter_shape.set_dimensions(input_feature_dim, 1);
    138   implicit_broadcast_filter_shape.set_dimensions(
    139       output_feature_dim, depthwise_multiplier * input_feature);
    140   return xla::Reshape(
    141       filter, xla::AsInt64Slice(implicit_broadcast_filter_shape.dimensions()));
    142 }
    143 
    144 // Reduces the results of the convolution with an expanded filter to the
    145 // non-expanded filter.
    146 xla::XlaOp ContractFilterForDepthwiseBackprop(const xla::Shape& filter_shape,
    147                                               const xla::XlaOp& filter_backprop,
    148                                               xla::XlaBuilder* builder) {
    149   auto masked_expanded_filter =
    150       xla::Select(CreateExpandedFilterMask(filter_shape, builder),
    151                   filter_backprop, xla::ZerosLike(filter_backprop));
    152 
    153   auto elem_type = filter_shape.element_type();
    154   return xla::Reshape(
    155       // This reduce does not need inputs to be converted with
    156       // XlaHelpers::SumAccumulationType() since the select above guarantees
    157       // that only one element is non zero, so there cannot be accumulated
    158       // precision error.
    159       xla::Reduce(masked_expanded_filter, xla::Zero(builder, elem_type),
    160                   CreateScalarAddComputation(elem_type, builder),
    161                   {filter_shape.dimensions_size() - 2}),
    162       xla::AsInt64Slice(filter_shape.dimensions()));
    163 }
    164 
    165 // Performs some basic checks on ConvOpAttrs that are true for all kinds of XLA
    166 // convolutions (as currently implemented).
    167 Status CheckConvAttrs(const ConvOpAttrs& attrs) {
    168   const int num_dims = attrs.num_spatial_dims + 2;
    169   if (attrs.strides.size() != num_dims) {
    170     return errors::InvalidArgument("Sliding window strides field must specify ",
    171                                    num_dims, " dimensions");
    172   }
    173   int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
    174   int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
    175   if (attrs.strides[batch_dim] != 1 || attrs.strides[feature_dim] != 1) {
    176     return errors::Unimplemented(
    177         "Current implementation does not yet support strides in the batch and "
    178         "depth dimensions.");
    179   }
    180   if (attrs.dilations.size() != num_dims) {
    181     return errors::InvalidArgument("Dilations field must specify ", num_dims,
    182                                    " dimensions");
    183   }
    184   if (attrs.dilations[batch_dim] != 1 || attrs.dilations[feature_dim] != 1) {
    185     return errors::Unimplemented(
    186         "Current implementation does not support dilations in the batch and "
    187         "depth dimensions.");
    188   }
    189   for (int i = 0; i < attrs.num_spatial_dims; ++i) {
    190     int input_dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
    191     if (attrs.dilations[input_dim] < 1) {
    192       return errors::Unimplemented("Dilation values must be positive; ", i,
    193                                    "th spatial dimension had dilation ",
    194                                    attrs.dilations[input_dim]);
    195     }
    196   }
    197   return Status::OK();
    198 }
    199 
    200 // Wrapper around ConvBackpropComputeDimensions that converts from XLA shapes
    201 // to TensorShapes.
    202 Status ConvBackpropComputeDimensionsV2XlaShapes(
    203     StringPiece label, int num_spatial_dims, const xla::Shape& input_shape,
    204     const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape,
    205     absl::Span<const int32> dilations, const std::vector<int32>& strides,
    206     Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims,
    207     absl::Span<const int64> explicit_paddings) {
    208   TensorShape input_tensor_shape, filter_tensor_shape,
    209       out_backprop_tensor_shape;
    210   TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
    211   TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape));
    212   TF_RETURN_IF_ERROR(
    213       XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape));
    214   return ConvBackpropComputeDimensionsV2(
    215       label, num_spatial_dims, input_tensor_shape, filter_tensor_shape,
    216       out_backprop_tensor_shape, dilations, strides, padding, explicit_paddings,
    217       data_format, dims);
    218 }
    219 
    220 }  // anonymous namespace
    221 
    222 xla::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(int num_spatial_dims,
    223                                                bool depthwise,
    224                                                OpKernelConstruction* ctx) {
    225   ConvOpAttrs attrs;
    226   attrs.num_spatial_dims = num_spatial_dims;
    227   attrs.depthwise = depthwise;
    228   TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations));
    229   TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides));
    230   TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding));
    231   if (attrs.padding == EXPLICIT) {
    232     TF_RETURN_IF_ERROR(
    233         ctx->GetAttr("explicit_paddings", &attrs.explicit_paddings));
    234   }
    235 
    236   string data_format;
    237   TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format));
    238   if (!FormatFromString(data_format, &attrs.data_format)) {
    239     return errors::InvalidArgument("Invalid data format: ", data_format);
    240   }
    241 
    242   return attrs;
    243 }
    244 
    245 xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
    246                                                xla::XlaOp conv_input,
    247                                                xla::XlaOp filter,
    248                                                const ConvOpAttrs& attrs) {
    249   TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
    250 
    251   auto* builder = conv_input.builder();
    252   TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(conv_input));
    253   // Filter has the form [filter_rows, filter_cols, ..., in_depth, out_depth]
    254   TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
    255 
    256   // For 2D convolution, there should be 4 dimensions.
    257   int num_dims = attrs.num_spatial_dims + 2;
    258   if (input_shape.dimensions_size() != num_dims) {
    259     return errors::InvalidArgument("input must be ", num_dims, "-dimensional",
    260                                    input_shape.DebugString());
    261   }
    262   if (filter_shape.dimensions_size() != num_dims) {
    263     return errors::InvalidArgument(
    264         "filter must be ", num_dims,
    265         "-dimensional: ", filter_shape.DebugString());
    266   }
    267 
    268   // The last two dimensions of the filter are the input and output shapes.
    269   int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
    270   int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
    271 
    272   int64 in_depth = filter_shape.dimensions(attrs.num_spatial_dims);
    273   // The 'C' dimension for input is in_depth. It must be the same as
    274   // the filter's in_depth.
    275   if (in_depth != input_shape.dimensions(feature_dim)) {
    276     return errors::InvalidArgument(
    277         "input and filter must have the same depth: ", in_depth, " vs ",
    278         input_shape.dimensions(feature_dim));
    279   }
    280 
    281   if (attrs.depthwise) {
    282     filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter);
    283   }
    284 
    285   xla::ConvolutionDimensionNumbers dims;
    286   std::vector<int64> window_strides(attrs.num_spatial_dims);
    287   std::vector<int64> lhs_dilation(attrs.num_spatial_dims, 1);
    288   std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
    289   std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
    290 
    291   dims.set_input_batch_dimension(batch_dim);
    292   dims.set_output_batch_dimension(batch_dim);
    293   dims.set_input_feature_dimension(feature_dim);
    294   dims.set_output_feature_dimension(feature_dim);
    295   dims.set_kernel_input_feature_dimension(attrs.num_spatial_dims);
    296   dims.set_kernel_output_feature_dimension(attrs.num_spatial_dims + 1);
    297 
    298   for (int i = 0; i < attrs.num_spatial_dims; ++i) {
    299     const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
    300     dims.add_input_spatial_dimensions(dim);
    301     dims.add_kernel_spatial_dimensions(i);
    302     dims.add_output_spatial_dimensions(dim);
    303     window_strides[i] = attrs.strides.at(dim);
    304     rhs_dilation[i] = attrs.dilations.at(dim);
    305 
    306     if (attrs.padding == EXPLICIT) {
    307       padding[i] = {attrs.explicit_paddings.at(dim * 2),
    308                     attrs.explicit_paddings.at(dim * 2 + 1)};
    309     }
    310 
    311     int64 unused_output_size;
    312     TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
    313         input_shape.dimensions(dim), filter_shape.dimensions(i),
    314         rhs_dilation[i], window_strides[i], attrs.padding, &unused_output_size,
    315         &padding[i].first, &padding[i].second));
    316   }
    317 
    318   return xla::ConvGeneralDilated(
    319       conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
    320       dims, /*feature_group_count=*/attrs.depthwise ? in_depth : 1);
    321 }
    322 
// Lowers the gradient-of-a-convolution-w.r.t.-its-input to an XLA
// ConvGeneralDilated: the output gradients, dilated by the forward stride,
// are convolved with the spatially mirrored filter. See the comment at the
// top of conv_grad_ops.h for the derivation.
xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
    StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
    xla::XlaOp out_backprop, const ConvOpAttrs& attrs) {
  TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));

  int num_dims = attrs.num_spatial_dims + 2;
  int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
  int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);

  auto* builder = filter.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
  TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
                      builder->GetShape(out_backprop));

  // For the depthwise case the dimension bookkeeping below is done against
  // the expanded ([H, W, ..., M, M*N]) filter shape.
  xla::Shape expanded_filter_shape =
      attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
                      : filter_shape;
  // Reuse dimension computation logic from conv_grad_ops.cc.
  ConvBackpropDimensions dims;
  TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
      type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape,
      out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding,
      attrs.data_format, &dims, attrs.explicit_paddings));

  // The input gradients are computed by a convolution of the output
  // gradients and the filter, with some appropriate padding. See the
  // comment at the top of conv_grad_ops.h for details.

  xla::ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(batch_dim);
  dnums.set_output_batch_dimension(batch_dim);
  dnums.set_input_feature_dimension(feature_dim);
  dnums.set_output_feature_dimension(feature_dim);

  // TF filter shape is [ H, W, ..., inC, outC ]
  // Transpose the input and output features for computing the gradient.
  dnums.set_kernel_input_feature_dimension(attrs.num_spatial_dims + 1);
  dnums.set_kernel_output_feature_dimension(attrs.num_spatial_dims);

  std::vector<int64> kernel_spatial_dims(attrs.num_spatial_dims);
  std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
  std::vector<int64> lhs_dilation(attrs.num_spatial_dims);
  std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
  std::vector<int64> ones(attrs.num_spatial_dims, 1);
  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
    int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
    dnums.add_input_spatial_dimensions(dim);
    dnums.add_kernel_spatial_dimensions(i);
    dnums.add_output_spatial_dimensions(dim);

    kernel_spatial_dims[i] = i;
    // Padding and LHS dilation come from the backprop dimension computation:
    // dilating the gradients (the LHS) by the forward stride turns this into
    // the transposed convolution that reconstructs the input extent.
    padding[i] = {dims.spatial_dims[i].pad_before,
                  dims.spatial_dims[i].pad_after};
    lhs_dilation[i] = dims.spatial_dims[i].stride;
    rhs_dilation[i] = attrs.dilations[dim];
  }

  // Mirror the filter in the spatial dimensions.
  xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims);

  // activation gradients
  //   = gradients (with padding and dilation) <conv> mirrored_weights
  return xla::ConvGeneralDilated(
      out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
      lhs_dilation, rhs_dilation, dnums,
      /*feature_group_count=*/
      // For depthwise convolutions the gradient is likewise grouped: the
      // out_backprop feature count divided by the depth multiplier (the
      // filter's trailing dimension) gives the group count.
      attrs.depthwise ? out_backprop_shape.dimensions(feature_dim) /
                            filter_shape.dimensions(attrs.num_spatial_dims + 1)
                      : 1);
}
    393 
// Lowers the gradient-of-a-convolution-w.r.t.-its-filter to an XLA
// ConvGeneralDilated in which the activations form the LHS and the output
// gradients form the RHS (kernel), with batch and feature roles swapped.
// See the comment at the top of conv_grad_ops.h for the derivation.
xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
    StringPiece type_string, xla::XlaOp activations,
    const xla::Shape& filter_shape, xla::XlaOp gradients,
    const ConvOpAttrs& attrs) {
  TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));

  auto* builder = activations.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape activations_shape,
                      builder->GetShape(activations));
  TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
                      builder->GetShape(gradients));
  xla::XlaOp filter_backprop;

  xla::Shape input_shape = activations_shape;
  xla::Shape output_shape = out_backprop_shape;

  // NOTE(review): only filter_tensor_shape is read below (for
  // use_batch_group_count); the input/output conversions seem to serve only
  // to validate that the XLA shapes are representable as TensorShapes.
  TensorShape input_tensor_shape, filter_tensor_shape, output_tensor_shape;
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape));
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(output_shape, &output_tensor_shape));

  // For the depthwise case, dimension computation runs on the expanded
  // ([H, W, ..., M, M*N]) filter shape.
  const xla::Shape expanded_filter_shape =
      attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
                      : filter_shape;
  // Reuse dimension computation logic from conv_grad_ops.cc.
  ConvBackpropDimensions dims;
  // The filter gradients are computed by a convolution of the input
  // activations and the output gradients, with some appropriate padding.
  // See the comment at the top of conv_grad_ops.h for details.
  xla::ConvolutionDimensionNumbers dnums;

  TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
      type_string, attrs.num_spatial_dims, activations_shape,
      expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides,
      attrs.padding, attrs.data_format, &dims, attrs.explicit_paddings));

  // The activations (inputs) form the LHS of the convolution.
  // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
  // For the gradient computation, we flip the roles of the batch and
  // feature dimensions.
  // Each spatial entry has size in_depth * batch

  // The last two dimensions of the filter are the input and output shapes.
  int num_dims = attrs.num_spatial_dims + 2;
  int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
  int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);

  // A depthwise filter with depth multiplier 1 can be differentiated directly
  // via batch_group_count grouping (see the ConvGeneralDilated call below),
  // skipping the expand/contract round trip used for multiplier > 1.
  bool use_batch_group_count =
      filter_tensor_shape.dim_size(num_dims - 1) == 1 && attrs.depthwise;

  std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
  std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
  std::vector<int64> window_strides(attrs.num_spatial_dims);
  std::vector<int64> ones(attrs.num_spatial_dims, 1);

  // Swap n_dim and c_dim in the activations.
  dnums.set_input_batch_dimension(c_dim);
  dnums.set_input_feature_dimension(n_dim);

  // The gradients become the RHS of the convolution.
  // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
  // where the batch becomes the input feature for the convolution.
  dnums.set_kernel_input_feature_dimension(n_dim);
  dnums.set_kernel_output_feature_dimension(c_dim);

  // The dimension swap below is needed because filter shape is KH,KW,F,DM.
  if (use_batch_group_count) {
    dnums.set_output_batch_dimension(attrs.num_spatial_dims + 1);
    dnums.set_output_feature_dimension(attrs.num_spatial_dims);
  } else {
    dnums.set_output_batch_dimension(attrs.num_spatial_dims);
    dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1);
  }

  // Tensorflow filter shape is [ H, W, ..., inC, outC ].
  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
    dnums.add_output_spatial_dimensions(i);
  }

  for (int64 i = 0; i < attrs.num_spatial_dims; ++i) {
    int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
    dnums.add_input_spatial_dimensions(dim);
    dnums.add_kernel_spatial_dimensions(dim);
    // Roles of stride and dilation are swapped relative to the forward conv:
    // the forward stride dilates the gradients (RHS), while the forward
    // dilation becomes the window stride over the activations.
    rhs_dilation[i] = dims.spatial_dims[i].stride;
    window_strides[i] = attrs.dilations[dim];

    // We will also need to pad the input with zeros such that after the
    // convolution, we get the right size for the filter.
    // The padded_in_rows should be such that when we convolve this with the
    // expanded_out_rows as a filter, we should get filter_rows back.

    const int64 padded_in_size =
        dims.spatial_dims[i].expanded_output_size +
        (dims.spatial_dims[i].filter_size - 1) * attrs.dilations[dim];

    // However it can be smaller than input_rows: in this
    // case it means some of the inputs are not used.
    //
    // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
    //
    // INPUT =  [ A  B  C ]
    //
    // FILTER = [ x y ]
    //
    // and the output will only have one column: a = A * x + B * y
    //
    // and input "C" is not used at all.
    //
    // We apply negative padding in this case.
    const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size;

    // + For the EXPLICIT padding, we pad the top/left side with the explicit
    //   padding and pad the bottom/right side with the remaining space.
    // + For the VALID padding, we don't pad anything on the top/left side
    //   and pad the bottom/right side with the remaining space.
    // + For the SAME padding, we pad top/left side the same as bottom/right
    //   side.
    //
    // In addition, if the padded input size is smaller than the input size,
    // we need to ignore some training elements of the input. We do this by
    // applying negative padding on the right/bottom.
    const int64 pad_before = attrs.padding == Padding::EXPLICIT
                                 ? attrs.explicit_paddings[2 * dim]
                                 : attrs.padding == Padding::SAME
                                       ? std::max<int64>(pad_total / 2, 0)
                                       : 0;
    padding[i] = {pad_before, pad_total - pad_before};
  }

  // Besides padding the input, we will also expand output_rows to
  //    expanded_out_rows = (output_rows - 1) * stride + 1
  // with zeros in between:
  //
  //      a . . . b . . . c . . . d . . . e
  //
  // This is done by specifying the window dilation factors in the
  // convolution HLO below.

  filter_backprop = xla::ConvGeneralDilated(
      activations, gradients, window_strides, padding, /*lhs_dilation=*/ones,
      rhs_dilation, dnums,
      /*feature_group_count=*/1,
      /*batch_group_count=*/use_batch_group_count ? dims.in_depth : 1);

  // For depth multiplier > 1 the convolution was run against the expanded
  // filter; collapse the result back to the original depthwise filter shape.
  if (!use_batch_group_count && attrs.depthwise) {
    filter_backprop = ContractFilterForDepthwiseBackprop(
        filter_shape, filter_backprop, activations.builder());
  }

  return filter_backprop;
}
    545 
    546 }  // namespace tensorflow
    547