/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA specific pooling ops.

#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/pooling_ops_common.h"

namespace tensorflow {
namespace {

// Superclass of pooling ops.
class PoolingOp : public XlaOpKernel {
 public:
  PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    if (ctx->num_inputs() == 1) {
      std::vector<int32> ksize_int;
      std::vector<int32> stride_int;
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int));
      OP_REQUIRES(ctx, ksize_int.size() == num_dims(),
                  errors::InvalidArgument("Sliding window ksize field must "
                                          "specify ",
                                          num_dims(), " dimensions"));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int));
      OP_REQUIRES(ctx, stride_int.size() == num_dims(),
                  errors::InvalidArgument("Sliding window stride field must "
                                          "specify ",
                                          num_dims(), " dimensions"));
      for (int i = 0; i < num_dims(); ++i) {
        ksize_.push_back(ksize_int[i]);
        stride_.push_back(stride_int[i]);
      }
    }
    Padding padding;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding));
    padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
  }

  int num_dims() const { return num_spatial_dims_ + 2; }

  // Builds an initial value to use in reductions.
  virtual xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b,
                                               DataType data_type) = 0;

  // The reduction operation to apply to each window.
  virtual const xla::Computation* Reduction(XlaOpKernelContext* ctx,
                                            DataType dtype) = 0;

  // A post-processing operation to apply to the output of the ReduceWindow.
  virtual xla::ComputationDataHandle PostProcessOutput(
      XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
      DataType dtype, const TensorShape& input_shape) = 0;

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationDataHandle input = ctx->Input(0);
    const TensorShape input_shape = ctx->InputShape(0);

    std::vector<int64> ksize = ksize_;
    std::vector<int64> stride = stride_;
    if (ctx->num_inputs() != 1) {
      const TensorShape ksize_shape = ctx->InputShape(1);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
                  errors::InvalidArgument("ksize must be a vector, not shape ",
                                          ksize_shape.DebugString()));
      OP_REQUIRES(ctx, ksize_shape.num_elements() == num_dims(),
                  errors::InvalidArgument("Sliding window ksize field must "
                                          "specify ",
                                          num_dims(), " dimensions"));
      ksize.clear();
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &ksize));

      const TensorShape stride_shape = ctx->InputShape(2);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
                  errors::InvalidArgument("stride must be a vector, not shape ",
                                          stride_shape.DebugString()));
      OP_REQUIRES(ctx, stride_shape.num_elements() == num_dims(),
                  errors::InvalidArgument("Sliding window stride field must "
                                          "specify ",
                                          num_dims(), " dimensions"));
      stride.clear();
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride));
    }
    OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
                errors::InvalidArgument("Input to ", type_string(),
                                        " operator must have ", num_dims(),
                                        " dimensions"));

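    // The pooling itself is a single ReduceWindow. For example, MaxPool in
    // NHWC format with ksize {1, 3, 3, 1} and stride {1, 2, 2, 1} reduces
    // each 3x3 spatial window with max, stepping by 2 in each spatial
    // dimension.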
    const DataType type = input_type(0);
    xla::ComputationDataHandle pooled = ctx->builder()->ReduceWindow(
        input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize,
        stride, padding_);
    ctx->SetOutput(0, PostProcessOutput(ctx, pooled, type, input_shape));
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64> ksize_;
  std::vector<int64> stride_;
  xla::Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};

class MaxPoolOp : public PoolingOp {
 public:
  MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims) {}

  xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b,
                                       DataType data_type) override {
    return XlaHelpers::MinValue(b, data_type);
  }

  const xla::Computation* Reduction(XlaOpKernelContext* ctx,
                                    DataType dtype) override {
    return ctx->GetOrCreateMax(dtype);
  }

  xla::ComputationDataHandle PostProcessOutput(
      XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
      DataType dtype, const TensorShape& input_shape) override {
    return output;
  }
};

class MaxPool2DOp : public MaxPoolOp {
 public:
  explicit MaxPool2DOp(OpKernelConstruction* ctx)
      : MaxPoolOp(ctx, /*num_spatial_dims=*/2) {
    string data_format_str;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }
};
REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp);
REGISTER_XLA_OP(Name("MaxPoolV2")
                    .CompileTimeConstInput("ksize")
                    .CompileTimeConstInput("strides"),
                MaxPool2DOp);

class MaxPool3DOp : public MaxPoolOp {
 public:
  explicit MaxPool3DOp(OpKernelConstruction* ctx)
      : MaxPoolOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp);

// Common computation shared between AvgPool and AvgPoolGrad. Divide each
// element of an image by the count of elements that contributed to that
// element during pooling.
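//
// For example, in 1D with window 3 and stride 1 over 4 elements:
// * with VALID padding every window covers exactly 3 input elements, so
//   each output is divided by 3 (the product of the window dimensions);
// * with SAME padding the input is padded to {pad, x0, x1, x2, x3, pad},
//   so the edge windows cover only 2 real elements and the per-element
//   counts are {2, 3, 3, 2}.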
static xla::ComputationDataHandle AvgPoolDivideByCount(
    XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
    DataType dtype, const TensorShape& input_shape, xla::Padding padding,
    const std::vector<int64>& ksize, const std::vector<int64>& stride,
    int num_spatial_dims, TensorFormat data_format) {
  if (padding == xla::Padding::kValid) {
    // In VALID padding, all windows have the same number of elements
    // contributing to each average. Divide by the window size everywhere to
    // get the average.
    int64 window_size =
        std::accumulate(ksize.begin(), ksize.end(), int64{1},
                        [](int64 a, int64 b) { return a * b; });

    auto divisor =
        XlaHelpers::IntegerLiteral(ctx->builder(), dtype, window_size);
    return ctx->builder()->Div(output, divisor);
  } else {
    // For SAME padding, the padding shouldn't be included in the
    // counts. We use another ReduceWindow to find the right counts.

    // TODO(phawkins): use a less brute-force way to compute this. Only
    // the boundary regions will have interesting values here.

    std::vector<int64> input_dim_sizes(num_spatial_dims);
    std::vector<int64> window_dims(num_spatial_dims);
    std::vector<int64> window_ksize(num_spatial_dims);
    std::vector<int64> window_stride(num_spatial_dims);
    for (int i = 0; i < num_spatial_dims; ++i) {
      int dim = GetTensorSpatialDimIndex(num_spatial_dims + 2, data_format, i);
      input_dim_sizes[i] = input_shape.dim_size(dim);
      window_dims[i] = dim;
      window_ksize[i] = ksize[dim];
      window_stride[i] = stride[dim];
    }

    // Build an array of 1s with the same spatial dimensions as the input.
    auto ones = ctx->builder()->Broadcast(
        XlaHelpers::One(ctx->builder(), dtype), input_dim_sizes);

    // Perform a ReduceWindow with the same window size, strides, and padding
    // to count the number of contributions to each result element.
    auto counts = ctx->builder()->ReduceWindow(
        ones, XlaHelpers::Zero(ctx->builder(), dtype),
        *ctx->GetOrCreateAdd(dtype), window_ksize, window_stride,
        xla::Padding::kSame);

    return ctx->builder()->Div(output, counts, window_dims);
  }
}

class AvgPoolOp : public PoolingOp {
 public:
  AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : PoolingOp(ctx, num_spatial_dims) {}

  xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b,
                                       DataType data_type) override {
    return XlaHelpers::Zero(b, data_type);
  }

  const xla::Computation* Reduction(XlaOpKernelContext* ctx,
                                    DataType dtype) override {
    return ctx->GetOrCreateAdd(dtype);
  }

  xla::ComputationDataHandle PostProcessOutput(
      XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
      DataType dtype, const TensorShape& input_shape) override {
    return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_,
                                ksize_, stride_, num_spatial_dims_,
                                data_format_);
  }
};

class AvgPool2DOp : public AvgPoolOp {
 public:
  explicit AvgPool2DOp(OpKernelConstruction* ctx)
      : AvgPoolOp(ctx, /*num_spatial_dims=*/2) {
    string data_format_str;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }
};
REGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp);

class AvgPool3DOp : public AvgPoolOp {
 public:
  explicit AvgPool3DOp(OpKernelConstruction* ctx)
      : AvgPoolOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(Name("AvgPool3D"), AvgPool3DOp);

// The operation to compute MaxPool gradients.
// It takes three inputs:
//   - The original input tensor
//   - The original output tensor
//   - Backprop tensor for output
// It produces one output: backprop tensor for input.
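//
// The gradient is built with an XLA SelectAndScatter over the original
// input: the select computation (>=) picks the position of each window's
// maximum, and the scatter computation (+) adds the corresponding output
// gradient at that position. For example, for input {1, 3, 2} with window
// 2, stride 1, and VALID padding, both windows select the middle element,
// so the input backprop is {0, g0 + g1, 0}.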
class MaxPoolGradOp : public XlaOpKernel {
 public:
  MaxPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    if (ctx->num_inputs() == 3) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
  }

  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    if (ctx->num_inputs() != 3) {
      OP_REQUIRES(
          ctx, ctx->num_inputs() == 5,
          errors::InvalidArgument("Must supply ksize and stride arguments."));
      const TensorShape ksize_shape = ctx->InputShape(3);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
                  errors::InvalidArgument("ksize must be a vector, not shape ",
                                          ksize_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));

      const TensorShape stride_shape = ctx->InputShape(4);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
                  errors::InvalidArgument("stride must be a vector, not shape ",
                                          stride_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
    }

    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));

    const TensorShape tensor_in_shape = ctx->InputShape(0);
    const TensorShape tensor_out_shape = ctx->InputShape(1);
    const TensorShape out_backprop_shape = ctx->InputShape(2);

    // For maxpooling, tensor_in should have num_dims() dimensions.
    OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_in must be ", num_dims(),
                                        "-dimensional"));
    OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_out must be ", num_dims(),
                                        "-dimensional"));
    // For maxpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));

    // TODO(phawkins): The XLA version doesn't need tensor_out. Investigate
    // whether this is a good time/space tradeoff.
    auto input = ctx->Input(0);
    auto out_backprop = ctx->Input(2);

    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;

    xla::PrimitiveType element_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
    xla::ComputationDataHandle init_value =
        XlaHelpers::Zero(ctx->builder(), input_type(2));
    auto select = CreateScalarGeComputation(element_type, ctx->builder());
    auto scatter = CreateScalarAddComputation(element_type, ctx->builder());
    xla::ComputationDataHandle gradients = ctx->builder()->SelectAndScatter(
        input, select, ksize_, stride_, xla_padding, out_backprop, init_value,
        scatter);

    ctx->SetOutput(0, gradients);
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64> ksize_;
  std::vector<int64> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};

class MaxPool2DGradOp : public MaxPoolGradOp {
 public:
  explicit MaxPool2DGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradOp(ctx, /*num_spatial_dims=*/2) {
    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }
};
REGISTER_XLA_OP(Name("MaxPoolGrad"), MaxPool2DGradOp);
REGISTER_XLA_OP(Name("MaxPoolGradV2")
                    .CompileTimeConstInput("ksize")
                    .CompileTimeConstInput("strides"),
                MaxPool2DGradOp);

class MaxPool3DGradOp : public MaxPoolGradOp {
 public:
  explicit MaxPool3DGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(Name("MaxPool3DGrad"), MaxPool3DGradOp);

// Average-pooling gradient
class AvgPoolGradOp : public XlaOpKernel {
 public:
  AvgPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
    OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
  }

  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape gradients_shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &gradients_shape));

    const TensorShape out_backprop_shape = ctx->InputShape(1);

    // For avgpooling, orig_input_shape should have num_dims() dimensions.
    OP_REQUIRES(ctx, gradients_shape.dims() == num_dims(),
                errors::InvalidArgument("orig_input_shape must be ", num_dims(),
                                        "-dimensional"));

    // For avgpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));

    int depth_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
    int64 depth = out_backprop_shape.dim_size(depth_dim);

    // We can think of average-pooling as:
    // * a convolution whose kernel is 1 wherever the input feature equals
    //   the output feature, and 0 everywhere else,
    // * followed by dividing by the counts.
    //
    // This then gives us an algorithm to build the gradient:
    // * divide out_backprop by the counts, followed by
    // * Conv2DBackpropInput specialized for that kernel, which simplifies to
    //   a Pad and a ReduceWindow.
    //
    // For an explanation of backpropagation for convolution, see the comments
    // in third_party/tensorflow/core/kernels/conv_grad_ops.h
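    //
    // As a 1D illustration, take window 2, stride 2, VALID padding, and
    // input length 4. Every input element lands in exactly one window, so
    // the counts are all 2 and out_backprop becomes {g0/2, g1/2}. Padding
    // with interior padding 1 (stride - 1) and edge padding 1 on each side
    // gives {0, g0/2, 0, g1/2, 0}, and a size-2 sum ReduceWindow with
    // stride 1 recovers the input gradients {g0/2, g0/2, g1/2, g1/2}.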

    // TF filter shape is [ H, W, ..., inC, outC ]
    std::vector<int64> filter_dims(num_dims());
    for (int i = 0; i < num_spatial_dims_; ++i) {
      int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
      filter_dims[i] = ksize_[dim];
    }
    filter_dims[num_dims() - 2] = depth;
    filter_dims[num_dims() - 1] = depth;
    TensorShape filter_shape(filter_dims);

    // Reuse the logic from Conv2DBackpropInput to compute padding.
    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(
        ctx, ConvBackpropComputeDimensions(
                 type_string(), /*num_spatial_dims=*/num_spatial_dims_,
                 gradients_shape, filter_shape, out_backprop_shape, stride_,
                 padding_, data_format_, &dims));

    auto out_backprop = ctx->Input(1);

    // The input gradients are computed by a convolution of the output
    // gradients and the filter, with some appropriate padding. See the
    // comment at the top of conv_grad_ops.h for details.
    DataType dtype = input_type(1);

    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;

    // Divide the out_backprop values by the counts for each spatial position.
    std::vector<int64> stride_int64s(stride_.begin(), stride_.end());
    auto out_backprop_div = AvgPoolDivideByCount(
        ctx, out_backprop, dtype, gradients_shape, xla_padding, ksize_,
        stride_int64s, num_spatial_dims_, data_format_);

    // Pad the gradients in the spatial dimensions. We use the same padding
    // as Conv2DBackpropInput.
    xla::PaddingConfig padding_config = xla::MakeNoPaddingConfig(num_dims());
    for (int i = 0; i < num_spatial_dims_; ++i) {
      int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
      auto* padding = padding_config.mutable_dimensions(dim);
      padding->set_edge_padding_low(dims.spatial_dims[i].pad_before);
      padding->set_edge_padding_high(dims.spatial_dims[i].pad_after);
      padding->set_interior_padding(dims.spatial_dims[i].stride - 1);
    }
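    // The interior padding of stride - 1 dilates the divided gradients to
    // the expanded output used in convolution backprop: with stride 2,
    // {a, b} becomes {a, 0, b}.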

    auto zero = XlaHelpers::Zero(ctx->builder(), dtype);
    auto padded_gradients =
        ctx->builder()->Pad(out_backprop_div, zero, padding_config);

    // in_backprop = padded_gradients <conv> ones
    std::vector<int64> ones(num_dims(), 1LL);
    xla::ComputationDataHandle in_backprop = ctx->builder()->ReduceWindow(
        padded_gradients, zero, *ctx->GetOrCreateAdd(dtype), ksize_,
        /*window_strides=*/ones, xla::Padding::kValid);

    ctx->SetOutput(0, in_backprop);
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};

class AvgPool2DGradOp : public AvgPoolGradOp {
 public:
  explicit AvgPool2DGradOp(OpKernelConstruction* ctx)
      : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) {
    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }
};
REGISTER_XLA_OP(Name("AvgPoolGrad").CompileTimeConstInput("orig_input_shape"),
                AvgPool2DGradOp);

class AvgPool3DGradOp : public AvgPoolGradOp {
 public:
  explicit AvgPool3DGradOp(OpKernelConstruction* ctx)
      : AvgPoolGradOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(Name("AvgPool3DGrad").CompileTimeConstInput("orig_input_shape"),
                AvgPool3DGradOp);

}  // anonymous namespace
}  // namespace tensorflow