/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/avgpooling_op.h"

#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/eigen_pooling.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/pooling_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/maxpooling_op_gpu.h"
#include "tensorflow/core/kernels/pooling_ops_common_gpu.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
class AvgPoolingOp : public UnaryOp<T> {
 public:
  explicit AvgPoolingOp(OpKernelConstruction* context) : UnaryOp<T>(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(
        context, data_format_ == FORMAT_NHWC,
        errors::InvalidArgument("Default AvgPoolingOp only supports NHWC ",
                                "on device type ",
                                DeviceTypeString(context->device_type())));
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 4,
                errors::InvalidArgument("Sliding window stride field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);
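    // The PoolParameters constructor validates the window/stride/shape
    // combination and records any error on the OpKernelContext, hence the
    // status check immediately below.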
    PoolParameters params{context,  ksize_,       stride_,
                          padding_, data_format_, tensor_in.shape()};
    if (!context->status().ok()) {
      return;
    }
    OP_REQUIRES(context, params.depth_window == 1,
                errors::Unimplemented("Non-spatial pooling is not "
                                      "yet supported. Volunteers? :)"));

    // For avgpooling, tensor_in should have 4 dimensions.
    OP_REQUIRES(context, tensor_in.dims() == 4,
                errors::InvalidArgument("tensor_in must be 4-dimensional"));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                0, params.forward_output_shape(), &output));

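    // SpatialAvgPool (defined in pooling_ops_common.h) computes the mean over
    // each pooling window on the CPU.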
    SpatialAvgPool<Device, T>(context, output, tensor_in, params, padding_);
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

REGISTER_KERNEL_BUILDER(
    Name("AvgPool").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    AvgPoolingOp<CPUDevice, float>);
REGISTER_KERNEL_BUILDER(
    Name("AvgPool").Device(DEVICE_CPU).TypeConstraint<Eigen::half>("T"),
    AvgPoolingOp<CPUDevice, Eigen::half>);

#if GOOGLE_CUDA
template <typename T>
class AvgPoolingOp<GPUDevice, T> : public UnaryOp<T> {
 public:
  typedef GPUDevice Device;
  explicit AvgPoolingOp(OpKernelConstruction* context) : UnaryOp<T>(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 4,
                errors::InvalidArgument("Sliding window stride field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N');
    const int32 stride_n = GetTensorDim(stride_, data_format_, 'N');
    OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);
    PoolParameters params{context,  ksize_,       stride_,
                          padding_, data_format_, tensor_in.shape()};
    if (!context->status().ok()) {
      return;
    }
    OP_REQUIRES(context, params.depth_window == 1,
                errors::Unimplemented("Non-spatial pooling is not "
                                      "yet supported. Volunteers? :)"));

    // For avgpooling, tensor_in should have 4 dimensions.
    OP_REQUIRES(context, tensor_in.dims() == 4,
                errors::InvalidArgument("tensor_in must be 4-dimensional"));

    TensorShape output_shape = params.forward_output_shape();

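    // NCHW inputs are dispatched to cuDNN through DnnPoolingOp; NHWC inputs
    // use the Eigen SpatialAvgPooling functor on the GPU.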
    if (data_format_ == FORMAT_NCHW) {
      DnnPoolingOp<T>::Compute(
          context, perftools::gputools::dnn::PoolingMode::kAverage, ksize_,
          stride_, padding_, data_format_, tensor_in, output_shape,
          /*propagate_nans=*/false);
    } else {
      Tensor* output = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, output_shape, &output));
      Eigen::PaddingType pt = BrainPadding2EigenPadding(padding_);
      functor::SpatialAvgPooling<Device, T>()(
          context->eigen_device<Device>(), output->tensor<T, 4>(),
          tensor_in.tensor<T, 4>(), params.window_rows, params.window_cols,
          params.row_stride, params.col_stride, pt);
    }
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                      \
  template <>                                                    \
  void SpatialAvgPooling<GPUDevice, T>::operator()(              \
      const GPUDevice& d, typename TTypes<T, 4>::Tensor output,  \
      typename TTypes<T, 4>::ConstTensor input, int window_rows, \
      int window_cols, int row_stride, int col_stride,           \
      const Eigen::PaddingType& padding);                        \
  extern template struct SpatialAvgPooling<GPUDevice, T>;

DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
#undef DECLARE_GPU_SPEC
}  // namespace functor

REGISTER_KERNEL_BUILDER(
    Name("AvgPool").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    AvgPoolingOp<GPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(
    Name("AvgPool").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    AvgPoolingOp<GPUDevice, float>);
#endif  // GOOGLE_CUDA

// The operation to compute AvgPool gradients.
// It takes two inputs:
//   - The original input tensor shape
//   - Backprop tensor for output
// It produces one output: backprop tensor for input.
template <typename Device, class T>
class AvgPoolingGradOp : public OpKernel {
 public:
  explicit AvgPoolingGradOp(OpKernelConstruction* context) : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(
        context, data_format_ == FORMAT_NHWC,
        errors::InvalidArgument("Default AvgPoolingGradOp only supports NHWC ",
                                "on device type ",
                                DeviceTypeString(context->device_type())));
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in_shape = context->input(0);
    const Tensor& out_backprop = context->input(1);
    // For avgpooling, tensor_in_shape should have 1 dimension, and 4 elements.
    OP_REQUIRES(
        context,
        tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 4,
        errors::InvalidArgument("tensor_in_shape must be 1-dimensional and "
                                "have 4 elements"));
    // For avgpooling, out_backprop should have 4 dimensions.
    OP_REQUIRES(context, out_backprop.dims() == 4,
                errors::InvalidArgument("out_backprop must be 4-dimensional"));
    const int64 out_backprop_batch = out_backprop.dim_size(0);
    const int64 out_backprop_rows = out_backprop.dim_size(1);
    const int64 out_backprop_cols = out_backprop.dim_size(2);
    const int64 out_backprop_depth = out_backprop.dim_size(3);

    TensorShape output_shape;
    auto shape_vec = tensor_in_shape.vec<int32>();
    for (int64 i = 0; i < tensor_in_shape.NumElements(); ++i) {
      output_shape.AddDim(shape_vec(i));
    }
    const int64 in_rows = output_shape.dim_size(1);
    const int64 in_cols = output_shape.dim_size(2);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    output->flat<T>().setZero();

    const int window_rows = ksize_[1];
    const int window_cols = ksize_[2];
    const int depth_window = ksize_[3];

    const int row_stride = stride_[1];
    const int col_stride = stride_[2];

    // We (will) use different code for spatial pooling and
    // non-spatial pooling.
    //
    // Spatial pooling is when depth_window = 1
    OP_REQUIRES(context, depth_window == 1,
                errors::Unimplemented("Non-spatial pooling is not "
                                      "yet supported. Volunteers? :)"));

    int64 out_height, out_width, pad_rows, pad_cols;
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(in_rows, window_rows, row_stride,
                                         padding_, &out_height, &pad_rows));
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(in_cols, window_cols, col_stride,
                                         padding_, &out_width, &pad_cols));

    const T* out_backprop_ptr = out_backprop.flat<T>().data();
    T* input_backprop_ptr = output->flat<T>().data();

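    // Each out_backprop element is distributed evenly over the input cells
    // covered by its pooling window, scaled by 1 / (rsize * csize). The work
    // is sharded over the batch dimension.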
    auto shard = [context, out_backprop_ptr, input_backprop_ptr,
                  out_backprop_rows, out_backprop_cols, out_backprop_depth,
                  in_rows, in_cols, window_rows, window_cols, row_stride,
                  col_stride, pad_rows, pad_cols](int64 start, int64 limit) {
      for (int64 b = start; b < limit; ++b) {
        for (int64 r = 0; r < out_backprop_rows; ++r) {
          // Calculates row broadcast size.  For SAME padding, current
          // index could be in the padding area, and r*row_stride +
          // window_rows could be beyond the input tensor's boundary. In
          // such cases, change the starting index and reduce the
          // broadcast size.
          int rindex, rsize;
          OP_REQUIRES_OK(context,
                         GetBroadcastSize(r, in_rows, window_rows, row_stride,
                                          pad_rows, &rindex, &rsize));
          for (int64 c = 0; c < out_backprop_cols; ++c) {
            // Calculates col broadcast size.  For SAME padding, current
            // index could be in the padding area, and c*col_stride +
            // window_cols could be beyond the input tensor's boundary. In
            // such cases, change the starting index and reduce the
            // broadcast size.
            int cindex, csize;
            OP_REQUIRES_OK(context,
                           GetBroadcastSize(c, in_cols, window_cols, col_stride,
                                            pad_cols, &cindex, &csize));

            T divide_coeff(1.0 / (rsize * csize));
            int64 output_index =
                (b * out_backprop_rows + r) * out_backprop_cols + c;
            for (int64 r_dst = rindex; r_dst < rindex + rsize; ++r_dst) {
              for (int64 c_dst = cindex; c_dst < cindex + csize; ++c_dst) {
                int64 input_index = (b * in_rows + r_dst) * in_cols + c_dst;
                const T* output_offset =
                    out_backprop_ptr + output_index * out_backprop_depth;
                T* input_offset =
                    input_backprop_ptr + input_index * out_backprop_depth;
                for (int64 d = 0; d < out_backprop_depth; ++d) {
                  *input_offset += *output_offset * divide_coeff;
                  ++output_offset;
                  ++input_offset;
                }
              }
            }
          }
        }
      }
    };

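    // shard_cost is a rough per-batch work estimate; Shard uses it to decide
    // how finely to split the batch range across the intra-op thread pool.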
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *(context->device()->tensorflow_cpu_worker_threads());
    const int64 shard_cost =
        window_rows * window_cols * depth_window * in_rows * in_rows * in_cols;
    Shard(worker_threads.num_threads, worker_threads.workers,
          out_backprop_batch, shard_cost, shard);
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

#define REGISTER_CPU_KERNEL(T)                                 \
  REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")                  \
                              .Device(DEVICE_CPU)              \
                              .TypeConstraint<T>("T")          \
                              .HostMemory("orig_input_shape"), \
                          AvgPoolingGradOp<CPUDevice, T>);

TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
TF_CALL_half(REGISTER_CPU_KERNEL);

#if GOOGLE_CUDA

// A cuDNN-based AvgPoolingGrad implementation. It includes the padded cells
// as candidates for the pooling operation.
template <class T>
class AvgPoolingGradOp<GPUDevice, T> : public OpKernel {
 public:
  typedef GPUDevice Device;

  explicit AvgPoolingGradOp(OpKernelConstruction* context) : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N');
    const int32 stride_n = GetTensorDim(stride_, data_format_, 'N');
    OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in_shape = context->input(0);
    const Tensor& out_backprop = context->input(1);
    // For avgpooling, tensor_in_shape should have 1 dimension, and 4 elements.
    OP_REQUIRES(
        context,
        tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 4,
        errors::InvalidArgument("tensor_in_shape must be 1-dimensional and "
                                "have 4 elements"));
    // For avgpooling, out_backprop should have 4 dimensions.
    OP_REQUIRES(context, out_backprop.dims() == 4,
                errors::InvalidArgument("out_backprop must be 4-dimensional"));

    TensorShape output_shape;
    auto shape_vec = tensor_in_shape.vec<int32>();
    for (int64 i = 0; i < tensor_in_shape.NumElements(); ++i) {
      output_shape.AddDim(shape_vec(i));
    }

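    // The original forward input and output tensors are not needed for the
    // average pooling backward pass, so nullptr is passed for both.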
    DnnPoolingGradOp<T>::Compute(
        context, perftools::gputools::dnn::PoolingMode::kAverage, ksize_,
        stride_, padding_, data_format_, nullptr, nullptr, out_backprop,
        output_shape, /*propagate_nans=*/false);
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

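// These registrations carry the "cudnn" label, so they are only selected when
// a kernel with that label is explicitly requested; the unlabeled default GPU
// kernel is AvgPoolingGradOpCustomGPUKernel below.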
REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .HostMemory("orig_input_shape")
                            .Label("cudnn"),
                        AvgPoolingGradOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .HostMemory("orig_input_shape")
                            .Label("cudnn"),
                        AvgPoolingGradOp<GPUDevice, Eigen::half>);

// A custom GPU-kernel-based AvgPoolingGrad implementation. It includes the
// padded cells as candidates for the pooling operation.
template <class T>
class AvgPoolingGradOpCustomGPUKernel : public OpKernel {
 public:
  typedef GPUDevice Device;

  explicit AvgPoolingGradOpCustomGPUKernel(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N');
    const int32 stride_n = GetTensorDim(stride_, data_format_, 'N');
    OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in_shape = context->input(0);
    const Tensor& out_backprop = context->input(1);
    // For avgpooling, tensor_in_shape should have 1 dimension, and 4 elements.
    OP_REQUIRES(
        context,
        tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 4,
        errors::InvalidArgument("tensor_in_shape must be 1-dimensional and "
                                "have 4 elements"));
    // For avgpooling, out_backprop should have 4 dimensions.
    OP_REQUIRES(context, out_backprop.dims() == 4,
                errors::InvalidArgument("out_backprop must be 4-dimensional"));
    TensorShape output_shape;
    auto shape_vec = tensor_in_shape.vec<int32>();
    for (int64 i = 0; i < tensor_in_shape.NumElements(); ++i) {
      output_shape.AddDim(shape_vec(i));
    }

    if (data_format_ == FORMAT_NHWC) {
      const int64 out_backprop_batch = out_backprop.dim_size(0);
      const int64 out_backprop_rows = out_backprop.dim_size(1);
      const int64 out_backprop_cols = out_backprop.dim_size(2);
      const int64 out_backprop_depth = out_backprop.dim_size(3);

      const int64 in_rows = output_shape.dim_size(1);
      const int64 in_cols = output_shape.dim_size(2);
      Tensor* output = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, output_shape, &output));

      const int window_rows = ksize_[1];
      const int window_cols = ksize_[2];
      const int depth_window = ksize_[3];

      const int row_stride = stride_[1];
      const int col_stride = stride_[2];

      // We (will) use different code for spatial pooling and
      // non-spatial pooling.
      //
      // Spatial pooling is when depth_window = 1
      OP_REQUIRES(context, depth_window == 1,
                  errors::Unimplemented("Non-spatial pooling is not "
                                        "yet supported. Volunteers? :)"));

      int64 out_height, out_width, pad_rows, pad_cols;
      OP_REQUIRES_OK(context,
                     GetWindowedOutputSize(in_rows, window_rows, row_stride,
                                           padding_, &out_height, &pad_rows));
      OP_REQUIRES_OK(context,
                     GetWindowedOutputSize(in_cols, window_cols, col_stride,
                                           padding_, &out_width, &pad_cols));

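      // RunAvePoolBackwardNHWC launches a custom CUDA kernel that, for each
      // input-gradient element, accumulates the contribution of every output
      // gradient whose pooling window covers it, each divided by the size of
      // that window.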
      RunAvePoolBackwardNHWC<T>(out_backprop.flat<T>().data(),  // top_diff
                                out_backprop_batch,             // num
                                in_rows,                        // height
                                in_cols,                        // width
                                out_backprop_depth,             // channels
                                out_backprop_rows,              // pooled_height
                                out_backprop_cols,              // pooled_width
                                window_rows,                    // kernel_h
                                window_cols,                    // kernel_w
                                row_stride,                     // stride_h
                                col_stride,                     // stride_w
                                pad_rows,                       // pad_t
                                pad_cols,                       // pad_l
                                output->flat<T>().data(),       // bottom_diff
                                context->eigen_gpu_device());   // d
    } else {
      DnnPoolingGradOp<T>::Compute(
          context, perftools::gputools::dnn::PoolingMode::kAverage, ksize_,
          stride_, padding_, data_format_, nullptr, nullptr, out_backprop,
          output_shape, /*propagate_nans=*/false);
    }
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .HostMemory("orig_input_shape"),
                        AvgPoolingGradOpCustomGPUKernel<float>);
REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .HostMemory("orig_input_shape"),
                        AvgPoolingGradOpCustomGPUKernel<Eigen::half>);

#endif  // GOOGLE_CUDA

}  // namespace tensorflow