// File: tensorflow/compiler/tf2xla/kernels/conv_ops.cc
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // XLA-specific Ops for 2D convolution.
     17 
     18 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
     19 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
     20 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     21 #include "tensorflow/compiler/xla/literal_util.h"
     22 #include "tensorflow/core/framework/numeric_op.h"
     23 #include "tensorflow/core/framework/op_kernel.h"
     24 #include "tensorflow/core/framework/tensor.h"
     25 #include "tensorflow/core/framework/tensor_shape.h"
     26 #include "tensorflow/core/framework/tensor_slice.h"
     27 #include "tensorflow/core/kernels/bounds_check.h"
     28 #include "tensorflow/core/kernels/conv_grad_ops.h"
     29 #include "tensorflow/core/kernels/ops_util.h"
     30 #include "tensorflow/core/util/padding.h"
     31 #include "tensorflow/core/util/tensor_format.h"
     32 
     33 namespace tensorflow {
     34 
     35 namespace {
     36 
     37 // Returns the expanded size of a filter used for depthwise convolution.
     38 // If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N].
     39 TensorShape ExpandedFilterShapeForDepthwiseConvolution(
     40     const TensorShape& shape) {
     41   int num_dims = shape.dims();
     42   CHECK_GE(num_dims, 2);
     43   TensorShape expanded_shape = shape;
     44   expanded_shape.set_dim(num_dims - 1, shape.dim_size(num_dims - 2) *
     45                                            shape.dim_size(num_dims - 1));
     46   return expanded_shape;
     47 }
     48 
     49 // Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution.
     50 xla::ComputationDataHandle CreateExpandedZero(
     51     const TensorShape& filter_shape, DataType dtype,
     52     xla::ComputationBuilder* builder) {
     53   TensorShape expanded_filter_shape =
     54       ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
     55   return builder->Broadcast(XlaHelpers::Zero(builder, dtype),
     56                             expanded_filter_shape.dim_sizes());
     57 }
     58 
// Create a mask for depthwise convolution that will make a normal convolution
// produce the same results as a depthwise convolution. For a [2, 2, 3, 2]
// depthwise filter this returns a [2, 2, 3, 6] tensor
//   1 1 0 0 0 0   1 1 0 0 0 0
//   0 0 1 1 0 0   0 0 1 1 0 0
//   0 0 0 0 1 1   0 0 0 0 1 1
//
//   1 1 0 0 0 0   1 1 0 0 0 0
//   0 0 1 1 0 0   0 0 1 1 0 0
//   0 0 0 0 1 1   0 0 0 0 1 1
//
// The first step is to create an iota tensor, A, that is [3]
//   0 1 2
//
// and another iota tensor, B, that is [3 * 2]
//   0 1 2 3 4 5
//
// and divide B by 2 to get
//   0 0 1 1 2 2
//
// then we broadcast the B to [2, 2, 3, 3 * 2]
//   0 0 1 1 2 2   0 0 1 1 2 2
//   0 0 1 1 2 2   0 0 1 1 2 2
//   0 0 1 1 2 2   0 0 1 1 2 2
//
//   0 0 1 1 2 2   0 0 1 1 2 2
//   0 0 1 1 2 2   0 0 1 1 2 2
//   0 0 1 1 2 2   0 0 1 1 2 2
//
// Finally compare A and broadcasted B in dimension 2 and return the result at
// the beginning of the comment.
xla::ComputationDataHandle CreateExpandedFilterMask(
    const TensorShape& filter_shape, xla::ComputationBuilder* builder) {
  TensorShape expanded_filter_shape =
      ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
  // M = input features, N = depthwise multiplier; filter_shape is
  // [H, W, ..., M, N].
  int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1);
  int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2);

  // Create a M sized linspace and an M*N sized linspace that will be
  // broadcasted into perpendicular dimensions and compared.
  xla::ComputationDataHandle input_feature_iota;
  // DT_INT32 Iota will always return status::OK().
  TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature,
                               &input_feature_iota));
  xla::ComputationDataHandle expanded_feature_iota;
  TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32,
                               input_feature * depthwise_multiplier,
                               &expanded_feature_iota));

  // Divide the M*N sized linspace by the depthwise_multiplier to create
  // [0 0 1 1 2 2] in the example in the function comment.
  expanded_feature_iota =
      builder->Div(expanded_feature_iota,
                   XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
                                              depthwise_multiplier));

  // Broadcast the N*M linspace to [H, W, ..., M, M*N].
  auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes();
  expanded_feature_broadcast_dims.pop_back();
  auto broadcasted_expanded_feature_iota = builder->Broadcast(
      expanded_feature_iota, expanded_feature_broadcast_dims);

  // Compare the broadcasted linspace to the input feature linspace in the
  // input feature dimension to create a diagonal predicate.
  return builder->Eq(broadcasted_expanded_feature_iota, input_feature_iota,
                     {expanded_filter_shape.dims() - 2});
}
    126 
// Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding
// zeros for the cross-depth filters. Used to build a depthwise convolution
// out of a regular convolution.
xla::ComputationDataHandle ExpandFilterForDepthwiseConvolution(
    const TensorShape& filter_shape, DataType dtype,
    const xla::ComputationDataHandle& filter,
    xla::ComputationBuilder* builder) {
  // N (per-channel multiplier) and M (input features) from the trailing two
  // dimensions of the depthwise filter.
  int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1);
  int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2);
  TensorShape expanded_filter_shape =
      ExpandedFilterShapeForDepthwiseConvolution(filter_shape);

  // Create a [H, W, ..., 1, N*M] reshape of the filter.
  TensorShape implicit_broadcast_filter_shape = expanded_filter_shape;
  implicit_broadcast_filter_shape.set_dim(
      implicit_broadcast_filter_shape.dims() - 2, 1);
  implicit_broadcast_filter_shape.set_dim(
      implicit_broadcast_filter_shape.dims() - 1,
      depthwise_multiplier * input_feature);
  auto implicit_broadcast_filter =
      builder->Reshape(filter, implicit_broadcast_filter_shape.dim_sizes());

  // Broadcast the filter to [H, W, ..., M, M*N] by adding a zero tensor of
  // the expanded shape (Add broadcasts the size-1 dimension).
  auto expanded_zero = CreateExpandedZero(filter_shape, dtype, builder);
  auto expanded_filter = builder->Add(implicit_broadcast_filter, expanded_zero);

  // If the filter mask is set, choose the broadcasted filter, otherwise,
  // choose zero.
  return builder->Select(CreateExpandedFilterMask(filter_shape, builder),
                         expanded_filter, expanded_zero);
}
    157 
// Inverse of ExpandFilterForDepthwiseConvolution: masks out the cross-depth
// entries of an expanded [H, W, ..., M, M*N] filter gradient, sums over the
// input-feature dimension, and reshapes back to [H, W, ..., M, N].
xla::ComputationDataHandle ContractFilterForDepthwiseBackprop(
    XlaOpKernelContext* ctx, const TensorShape& filter_shape, DataType dtype,
    const xla::ComputationDataHandle& filter_backprop,
    xla::ComputationBuilder* builder) {
  TensorShape expanded_filter_shape =
      ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
  // Zero out everything except the "diagonal" entries produced by the
  // depthwise expansion so they do not contribute to the reduction below.
  auto masked_expanded_filter = builder->Select(
      CreateExpandedFilterMask(filter_shape, builder), filter_backprop,
      CreateExpandedZero(filter_shape, dtype, builder));
  // Reduce-sum over the input-feature dimension (dims-2 of the expanded
  // shape) and restore the original depthwise filter shape.
  return builder->Reshape(
      builder->Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype),
                      *ctx->GetOrCreateAdd(dtype),
                      {expanded_filter_shape.dims() - 2}),
      filter_shape.dim_sizes());
}
    174 
// XLA kernel for forward N-D convolution. Subclasses only choose the number
// of spatial dimensions and whether the convolution is depthwise; depthwise
// convolutions are lowered to a regular convolution on a zero-expanded
// filter (see ExpandFilterForDepthwiseConvolution).
class ConvOp : public XlaOpKernel {
 public:
  explicit ConvOp(OpKernelConstruction* ctx, int num_spatial_dims,
                  bool depthwise)
      : XlaOpKernel(ctx),
        num_spatial_dims_(num_spatial_dims),
        depthwise_(depthwise) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));

    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }

  // Total tensor rank: the spatial dimensions plus batch and feature.
  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    // Validate the strides/dilations attributes: correct rank, and no
    // striding or dilation in the batch/feature dimensions.
    OP_REQUIRES(ctx, strides_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_);
    int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
    OP_REQUIRES(
        ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
        errors::Unimplemented("Current implementation does not yet support "
                              "strides in the batch and depth dimensions."));

    OP_REQUIRES(ctx, dilations_.size() == num_dims(),
                errors::InvalidArgument("Dilations field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES(
        ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
        errors::Unimplemented("Current implementation does not support "
                              "dilations in the batch and depth dimensions."));
    for (int i = 0; i < num_spatial_dims_; ++i) {
      int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
      OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
                  errors::Unimplemented("Dilation values must be positive; ", i,
                                        "th spatial dimension had dilation ",
                                        dilations_[input_dim]));
    }

    const TensorShape input_shape = ctx->InputShape(0);
    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, ..., in_depth, out_depth]
    const TensorShape filter_shape = ctx->InputShape(1);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(
        ctx, input_shape.dims() == num_dims(),
        errors::InvalidArgument("input must be ", num_dims(), "-dimensional",
                                input_shape.DebugString()));
    OP_REQUIRES(
        ctx, filter_shape.dims() == num_dims(),
        errors::InvalidArgument("filter must be ", num_dims(),
                                "-dimensional: ", filter_shape.DebugString()));

    // The last two dimension of the filter are the input and output shapes.
    const int64 in_depth = filter_shape.dim_size(num_spatial_dims_);

    // The 'C' dimension for input is in_depth. It must be the same as
    // the filter's in_depth.
    OP_REQUIRES(ctx, in_depth == input_shape.dim_size(feature_dim),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", input_shape.dim_size(feature_dim)));

    xla::ComputationBuilder* b = ctx->builder();

    xla::ComputationDataHandle filter = ctx->Input(1);
    TensorShape expanded_filter_shape = filter_shape;
    if (depthwise_) {
      // Rewrite the depthwise filter as an equivalent dense filter so the
      // generic convolution below produces depthwise results.
      filter = ExpandFilterForDepthwiseConvolution(
          filter_shape, ctx->input_type(0), filter, b);
      expanded_filter_shape =
          ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
    }

    // Map TF data_format dimensions onto XLA convolution dimension numbers.
    // TF filters are always [spatial..., in_features, out_features].
    xla::ConvolutionDimensionNumbers dims;
    std::vector<int64> window_strides(num_spatial_dims_);
    std::vector<int64> lhs_dilation(num_spatial_dims_, 1);
    std::vector<int64> rhs_dilation(num_spatial_dims_);
    std::vector<std::pair<int64, int64>> padding(num_spatial_dims_);

    dims.set_input_batch_dimension(batch_dim);
    dims.set_output_batch_dimension(batch_dim);
    dims.set_input_feature_dimension(feature_dim);
    dims.set_output_feature_dimension(feature_dim);
    dims.set_kernel_input_feature_dimension(num_spatial_dims_);
    dims.set_kernel_output_feature_dimension(num_spatial_dims_ + 1);

    for (int i = 0; i < num_spatial_dims_; ++i) {
      const int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
      dims.add_input_spatial_dimensions(dim);
      dims.add_kernel_spatial_dimensions(i);
      dims.add_output_spatial_dimensions(dim);
      window_strides[i] = strides_.at(dim);
      rhs_dilation[i] = dilations_.at(dim);

      // Compute explicit low/high padding for this spatial dimension from the
      // TF Padding enum; the output size itself is not needed here.
      int64 unused_output_size;
      OP_REQUIRES_OK(
          ctx, GetWindowedOutputSizeVerboseV2(
                   input_shape.dim_size(dim), expanded_filter_shape.dim_size(i),
                   rhs_dilation[i], window_strides[i], padding_,
                   &unused_output_size, &padding[i].first, &padding[i].second));
    }

    xla::ComputationDataHandle conv =
        b->ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
                              lhs_dilation, rhs_dilation, dims);
    ctx->SetOutput(0, conv);
  }

 protected:
  const int num_spatial_dims_;
  const bool depthwise_;
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(ConvOp);
};
    304 
// Forward 2D convolution.
class Conv2DOp : public ConvOp {
 public:
  explicit Conv2DOp(OpKernelConstruction* ctx)
      : ConvOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {}
};
REGISTER_XLA_OP(Name("Conv2D"), Conv2DOp);
    311 
// Forward 3D convolution.
class Conv3DOp : public ConvOp {
 public:
  explicit Conv3DOp(OpKernelConstruction* ctx)
      : ConvOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {}
};
REGISTER_XLA_OP(Name("Conv3D"), Conv3DOp);
    318 
// Forward depthwise 2D convolution (each input channel convolved with its
// own set of filters).
class DepthwiseConv2DOp : public ConvOp {
 public:
  explicit DepthwiseConv2DOp(OpKernelConstruction* ctx)
      : ConvOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {}
};
REGISTER_XLA_OP(Name("DepthwiseConv2dNative"), DepthwiseConv2DOp);
    325 
// Backprop for input: computes d(loss)/d(input) as a convolution of the
// (dilated) output gradients with the spatially-mirrored filter. Base class
// for Conv{2,3}DBackpropInput and DepthwiseConv2dNativeBackpropInput.
class ConvBackpropInputOp : public XlaOpKernel {
 public:
  explicit ConvBackpropInputOp(OpKernelConstruction* ctx, int num_spatial_dims,
                               bool depthwise)
      : XlaOpKernel(ctx),
        num_spatial_dims_(num_spatial_dims),
        depthwise_(depthwise) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }

  // Total tensor rank: the spatial dimensions plus batch and feature.
  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    // Validate the strides/dilations attributes: correct rank, and no
    // striding or dilation in the batch/feature dimensions.
    OP_REQUIRES(ctx, strides_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_);
    int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
    OP_REQUIRES(
        ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
        errors::Unimplemented("Current implementation does not yet support "
                              "strides in the batch and depth dimensions."));

    OP_REQUIRES(ctx, dilations_.size() == num_dims(),
                errors::InvalidArgument("Dilations field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES(
        ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
        errors::Unimplemented("Current implementation does not support "
                              "dilations in the batch and depth dimensions."));
    for (int i = 0; i < num_spatial_dims_; ++i) {
      int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
      OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
                  errors::Unimplemented("Dilation values must be positive; ", i,
                                        "th spatial dimension had dilation ",
                                        dilations_[input_dim]));
    }

    // Input 0 is the desired input shape, supplied as a compile-time
    // constant (see the CompileTimeConstInput registration below).
    TensorShape input_shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape));

    const TensorShape filter_shape = ctx->InputShape(1);
    const TensorShape out_backprop_shape = ctx->InputShape(2);

    const TensorShape expanded_filter_shape =
        depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
                   : filter_shape;
    // Reuse dimension computation logic from conv_grad_ops.cc.
    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(ctx,
                   ConvBackpropComputeDimensionsV2(
                       type_string(), num_spatial_dims_, input_shape,
                       expanded_filter_shape, out_backprop_shape, dilations_,
                       strides_, padding_, data_format_, &dims));

    xla::ComputationBuilder* b = ctx->builder();
    auto filter = ctx->Input(1);
    auto out_backprop = ctx->Input(2);

    // The input gradients are computed by a convolution of the output
    // gradients and the filter, with some appropriate padding. See the
    // comment at the top of conv_grad_ops.h for details.

    xla::ConvolutionDimensionNumbers dnums;
    dnums.set_input_batch_dimension(batch_dim);
    dnums.set_output_batch_dimension(batch_dim);
    dnums.set_input_feature_dimension(feature_dim);
    dnums.set_output_feature_dimension(feature_dim);

    // TF filter shape is [ H, W, ..., inC, outC ]
    // Transpose the input and output features for computing the gradient.
    dnums.set_kernel_input_feature_dimension(num_spatial_dims_ + 1);
    dnums.set_kernel_output_feature_dimension(num_spatial_dims_);

    std::vector<int64> kernel_spatial_dims(num_spatial_dims_);
    std::vector<std::pair<int64, int64>> padding(num_spatial_dims_);
    std::vector<int64> lhs_dilation(num_spatial_dims_);
    std::vector<int64> rhs_dilation(num_spatial_dims_);
    std::vector<int64> ones(num_spatial_dims_, 1);
    for (int i = 0; i < num_spatial_dims_; ++i) {
      int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
      dnums.add_input_spatial_dimensions(dim);
      dnums.add_kernel_spatial_dimensions(i);
      dnums.add_output_spatial_dimensions(dim);

      kernel_spatial_dims[i] = i;
      padding[i] = {dims.spatial_dims[i].pad_before,
                    dims.spatial_dims[i].pad_after};
      // The forward stride becomes a dilation of the gradient (lhs);
      // the forward dilation becomes a dilation of the filter (rhs).
      lhs_dilation[i] = dims.spatial_dims[i].stride;
      rhs_dilation[i] = dilations_[dim];
    }

    // If this is a depthwise convolution, expand the filter.
    if (depthwise_) {
      filter = ExpandFilterForDepthwiseConvolution(
          filter_shape, ctx->input_type(1), filter, b);
    }

    // Mirror the filter in the spatial dimensions.
    xla::ComputationDataHandle mirrored_weights =
        b->Rev(filter, kernel_spatial_dims);

    // activation gradients
    //   = gradients (with padding and dilation) <conv> mirrored_weights
    xla::ComputationDataHandle in_backprop = b->ConvGeneralDilated(
        out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
        lhs_dilation, rhs_dilation, dnums);

    ctx->SetOutput(0, in_backprop);
  }

 protected:
  const int num_spatial_dims_;
  const bool depthwise_;
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropInputOp);
};
    457 
// Input gradient for 2D convolution.
class Conv2DBackpropInputOp : public ConvBackpropInputOp {
 public:
  explicit Conv2DBackpropInputOp(OpKernelConstruction* ctx)
      : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {}
};
// "input_sizes" must be a compile-time constant: Compile reads it with
// ConstantInputAsShape.
REGISTER_XLA_OP(
    Name("Conv2DBackpropInput").CompileTimeConstInput("input_sizes"),
    Conv2DBackpropInputOp);
    466 
// Input gradient for 3D convolution.
class Conv3DBackpropInputOp : public ConvBackpropInputOp {
 public:
  explicit Conv3DBackpropInputOp(OpKernelConstruction* ctx)
      : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {}
};
// "input_sizes" must be a compile-time constant: Compile reads it with
// ConstantInputAsShape.
REGISTER_XLA_OP(
    Name("Conv3DBackpropInputV2").CompileTimeConstInput("input_sizes"),
    Conv3DBackpropInputOp);
    475 
// Input gradient for depthwise 2D convolution.
class DepthwiseConv2DBackpropInputOp : public ConvBackpropInputOp {
 public:
  explicit DepthwiseConv2DBackpropInputOp(OpKernelConstruction* ctx)
      : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {}
};
REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropInput")
                    .CompileTimeConstInput("input_sizes"),
                DepthwiseConv2DBackpropInputOp);
    484 
    485 class ConvBackpropFilterOp : public XlaOpKernel {
    486  public:
    487   explicit ConvBackpropFilterOp(OpKernelConstruction* ctx, int num_spatial_dims,
    488                                 bool depthwise)
    489       : XlaOpKernel(ctx),
    490         num_spatial_dims_(num_spatial_dims),
    491         depthwise_(depthwise) {
    492     OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
    493     OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
    494     OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
    495     string data_format;
    496     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    497     OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
    498                 errors::InvalidArgument("Invalid data format"));
    499   }
    500 
    501   int num_dims() const { return num_spatial_dims_ + 2; }
    502 
    503   void Compile(XlaOpKernelContext* ctx) override {
    504     const int n_dim = GetTensorBatchDimIndex(num_dims(), data_format_);
    505     const int c_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
    506 
    507     OP_REQUIRES(
    508         ctx, (strides_[n_dim] == 1 && strides_[c_dim] == 1),
    509         errors::InvalidArgument("Current implementation does not yet support "
    510                                 "strides in the batch and depth dimensions."));
    511 
    512     OP_REQUIRES(ctx, dilations_.size() == num_dims(),
    513                 errors::InvalidArgument("Dilations field must "
    514                                         "specify ",
    515                                         num_dims(), " dimensions"));
    516     OP_REQUIRES(
    517         ctx, dilations_[n_dim] == 1 && dilations_[c_dim] == 1,
    518         errors::Unimplemented("Current implementation does not support "
    519                               "dilations in the batch and depth dimensions."));
    520     for (int i = 0; i < num_spatial_dims_; ++i) {
    521       int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
    522       OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
    523                   errors::Unimplemented("Dilation values must be positive; ", i,
    524                                         "th spatial dimension had dilation ",
    525                                         dilations_[input_dim]));
    526     }
    527 
    528     const TensorShape activations_shape = ctx->InputShape(0);
    529     TensorShape filter_shape;
    530     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape));
    531     const TensorShape out_backprop_shape = ctx->InputShape(2);
    532 
    533     const TensorShape expanded_filter_shape =
    534         depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
    535                    : filter_shape;
    536 
    537     // Reuse dimension computation logic from conv_grad_ops.cc.
    538     ConvBackpropDimensions dims;
    539     OP_REQUIRES_OK(ctx,
    540                    ConvBackpropComputeDimensionsV2(
    541                        type_string(), num_spatial_dims_, activations_shape,
    542                        expanded_filter_shape, out_backprop_shape, dilations_,
    543                        strides_, padding_, data_format_, &dims));
    544 
    545     xla::ComputationBuilder* b = ctx->builder();
    546     xla::ComputationDataHandle activations = ctx->Input(0);
    547     xla::ComputationDataHandle gradients = ctx->Input(2);
    548 
    549     // The filter gradients are computed by a convolution of the input
    550     // activations and the output gradients, with some appropriate padding.
    551     // See the comment at the top of conv_grad_ops.h for details.
    552 
    553     xla::ConvolutionDimensionNumbers dnums;
    554 
    555     // The activations (inputs) form the LHS of the convolution.
    556     // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
    557     // For the gradient computation, we flip the roles of the batch and
    558     // feature dimensions.
    559     // Each spatial entry has size in_depth * batch
    560 
    561     // Swap n_dim and c_dim in the activations.
    562     dnums.set_input_batch_dimension(c_dim);
    563     dnums.set_input_feature_dimension(n_dim);
    564 
    565     // The gradients become the RHS of the convolution.
    566     // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
    567     // where the batch becomes the input feature for the convolution.
    568     dnums.set_kernel_input_feature_dimension(n_dim);
    569     dnums.set_kernel_output_feature_dimension(c_dim);
    570 
    571     std::vector<std::pair<int64, int64>> padding(num_spatial_dims_);
    572     std::vector<int64> rhs_dilation(num_spatial_dims_);
    573     std::vector<int64> window_strides(num_spatial_dims_);
    574     std::vector<int64> ones(num_spatial_dims_, 1);
    575 
    576     // Tensorflow filter shape is [ H, W, ..., inC, outC ].
    577     for (int i = 0; i < num_spatial_dims_; ++i) {
    578       dnums.add_output_spatial_dimensions(i);
    579     }
    580     dnums.set_output_batch_dimension(num_spatial_dims_);
    581     dnums.set_output_feature_dimension(num_spatial_dims_ + 1);
    582 
    583     for (int i = 0; i < num_spatial_dims_; ++i) {
    584       int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
    585       dnums.add_input_spatial_dimensions(dim);
    586       dnums.add_kernel_spatial_dimensions(dim);
    587 
    588       // We will also need to pad the input with zeros such that after the
    589       // convolution, we get the right size for the filter.
    590       // The padded_in_rows should be such that when we convolve this with the
    591       // expanded_out_rows as a filter, we should get filter_rows back.
    592       //
    593       const int64 padded_in_size =
    594           dims.spatial_dims[i].expanded_output_size +
    595           (dims.spatial_dims[i].filter_size - 1) * dilations_[dim];
    596 
    597       // However it can be smaller than input_rows: in this
    598       // case it means some of the inputs are not used.
    599       //
    600       // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
    601       //
    602       // INPUT =  [ A  B  C ]
    603       //
    604       // FILTER = [ x y ]
    605       //
    606       // and the output will only have one column: a = A * x + B * y
    607       //
    608       // and input "C" is not used at all.
    609       //
    610       // We apply negative padding in this case.
    611       const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size;
    612 
    613       // + For the VALID padding, we don't pad anything on the top/left side
    614       //   and pad the bottom/right side with the remaining space.
    615       // + For the SAME padding, we pad top/left side the same as bottom/right
    616       //   side.
    617       //
    618       // In addition, if the padded input size is smaller than the input size,
    619       // we need to ignore some training elements of the input. We do this by
    620       // applying negative padding on the right/bottom.
    621       const int64 pad_before =
    622           padding_ == Padding::SAME ? std::max<int64>(pad_total / 2, 0) : 0;
    623 
    624       padding[i] = {pad_before, pad_total - pad_before};
    625       rhs_dilation[i] = dims.spatial_dims[i].stride;
    626       window_strides[i] = dilations_[dim];
    627     }
    628 
    629     // Besides padding the input, we will also expand output_rows to
    630     //    expanded_out_rows = (output_rows - 1) * stride + 1
    631     // with zeros in between:
    632     //
    633     //      a . . . b . . . c . . . d . . . e
    634     //
    635     // This is done by specifying the window dilation factors in the
    636     // convolution HLO below.
    637     auto filter_backprop =
    638         b->ConvGeneralDilated(activations, gradients, window_strides, padding,
    639                               /*lhs_dilation=*/ones, rhs_dilation, dnums);
    640 
    641     if (depthwise_) {
    642       filter_backprop = ContractFilterForDepthwiseBackprop(
    643           ctx, filter_shape, ctx->input_type(0), filter_backprop, b);
    644     }
    645     ctx->SetOutput(0, filter_backprop);
    646   }
    647 
    648  protected:
    649   const int num_spatial_dims_;
    650   const bool depthwise_;
    651   std::vector<int32> dilations_;
    652   std::vector<int32> strides_;
    653   Padding padding_;
    654   TensorFormat data_format_ = FORMAT_NHWC;
    655 
    656  private:
    657   TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropFilterOp);
    658 };
    659 
// Filter gradient for 2D convolution.
class Conv2DBackpropFilterOp : public ConvBackpropFilterOp {
 public:
  explicit Conv2DBackpropFilterOp(OpKernelConstruction* ctx)
      : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {
  }
};
// "filter_sizes" must be a compile-time constant: Compile reads it with
// ConstantInputAsShape.
REGISTER_XLA_OP(
    Name("Conv2DBackpropFilter").CompileTimeConstInput("filter_sizes"),
    Conv2DBackpropFilterOp);
    669 
// Filter gradient for 3D convolution.
class Conv3DBackpropFilterOp : public ConvBackpropFilterOp {
 public:
  explicit Conv3DBackpropFilterOp(OpKernelConstruction* ctx)
      : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {
  }
};
// "filter_sizes" must be a compile-time constant: Compile reads it with
// ConstantInputAsShape.
REGISTER_XLA_OP(
    Name("Conv3DBackpropFilterV2").CompileTimeConstInput("filter_sizes"),
    Conv3DBackpropFilterOp);
    679 
// Filter gradient for depthwise 2D convolution.
class DepthwiseConv2DBackpropFilterOp : public ConvBackpropFilterOp {
 public:
  explicit DepthwiseConv2DBackpropFilterOp(OpKernelConstruction* ctx)
      : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {}
};
REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropFilter")
                    .CompileTimeConstInput("filter_sizes"),
                DepthwiseConv2DBackpropFilterOp);
    688 
    689 }  // namespace
    690 }  // namespace tensorflow
    691