/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/pooling_ops_3d.h"

#include <array>

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/eigen_pooling.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/work_sharder.h"

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cudnn_pooling_gpu.h"
#include "tensorflow/core/kernels/pooling_ops_3d_gpu.h"
#endif

#ifdef TENSORFLOW_USE_SYCL
#include "tensorflow/core/kernels/pooling_ops_3d_sycl.h"
#endif  // TENSORFLOW_USE_SYCL

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

Pool3dParameters::Pool3dParameters(OpKernelContext* context,
                                   const std::vector<int32>& ksize,
                                   const std::vector<int32>& stride,
                                   Padding padding, TensorFormat data_format,
                                   const TensorShape& tensor_in_shape) {
  // For 3D pooling, tensor_in should have 5 dimensions.
  OP_REQUIRES(context, tensor_in_shape.dims() == 5,
              errors::InvalidArgument("tensor_in must be 5-dimensional"));

  this->data_format = data_format;
  depth = GetTensorDim(tensor_in_shape, data_format, 'C');
  tensor_in_planes = GetTensorDim(tensor_in_shape, data_format, '0');
  tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, '1');
  tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, '2');
  tensor_in_batch = GetTensorDim(tensor_in_shape, data_format, 'N');
  window_planes = GetTensorDim(ksize, data_format, '0');
  window_rows = GetTensorDim(ksize, data_format, '1');
  window_cols = GetTensorDim(ksize, data_format, '2');
  depth_window = GetTensorDim(ksize, data_format, 'C');
  plane_stride = GetTensorDim(stride, data_format, '0');
  row_stride = GetTensorDim(stride, data_format, '1');
  col_stride = GetTensorDim(stride, data_format, '2');
  depth_stride = GetTensorDim(stride, data_format, 'C');

  // We only support 3D pooling across plane/width/height. Depthwise
  // pooling is not supported.
  OP_REQUIRES(
      context, depth_window == 1 && depth_stride == 1,
      errors::Unimplemented(
          "Pooling3d only supports pooling across plane/width/height."));

  OP_REQUIRES_OK(context, GetWindowedOutputSize(tensor_in_planes, window_planes,
                                                plane_stride, padding,
                                                &out_plane, &pad_planes));
  OP_REQUIRES_OK(context,
                 GetWindowedOutputSize(tensor_in_rows, window_rows, row_stride,
                                       padding, &out_height, &pad_rows));
  OP_REQUIRES_OK(context,
                 GetWindowedOutputSize(tensor_in_cols, window_cols, col_stride,
                                       padding, &out_width, &pad_cols));
}

TensorShape Pool3dParameters::forward_output_shape() {
  return ShapeFromFormat(data_format, tensor_in_batch,
                         {{out_plane, out_height, out_width}}, depth);
}

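// Example (illustrative sketch; nothing in this file calls it this way):
// computing the forward output shape for a 2x2x2, stride-2, VALID max pool
// over an NDHWC input of shape [1, 8, 8, 8, 3]. ksize/stride use the
// 5-element [N, D, H, W, C] layout these ops expect, and FORMAT_NHWC doubles
// as the 5D NDHWC format:
//
//   std::vector<int32> ksize = {1, 2, 2, 2, 1};
//   std::vector<int32> stride = {1, 2, 2, 2, 1};
//   Pool3dParameters params(context, ksize, stride, VALID,
//                           FORMAT_NHWC, tensor_in.shape());
//   TensorShape out_shape = params.forward_output_shape();
//   // out_shape is [1, 4, 4, 4, 3]
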
template <typename T>
struct LaunchPoolingOp<CPUDevice, T, AVG> {
  static void launch(OpKernelContext* context, const Tensor& tensor_in,
                     const std::array<int64, 3>& window,
                     const std::array<int64, 3>& stride,
                     const std::array<int64, 3>& padding,
                     TensorFormat data_format, Padding padding_type,
                     Tensor* output) {
    output->tensor<T, 5>().device(context->eigen_device<CPUDevice>()) =
        Eigen::CuboidAvgPooling(tensor_in.tensor<T, 5>(), window[0], window[1],
                                window[2], stride[0], stride[1], stride[2],
                                BrainPadding2EigenPadding(padding_type));
  }
};

template <typename T>
struct LaunchPoolingOp<CPUDevice, T, MAX> {
  static void launch(OpKernelContext* context, const Tensor& tensor_in,
                     const std::array<int64, 3>& window,
                     const std::array<int64, 3>& stride,
                     const std::array<int64, 3>& padding,
                     TensorFormat data_format, Padding padding_type,
                     Tensor* output) {
    output->tensor<T, 5>().device(context->eigen_device<CPUDevice>()) =
        Eigen::CuboidMaxPooling(tensor_in.tensor<T, 5>(), window[0], window[1],
                                window[2], stride[0], stride[1], stride[2],
                                BrainPadding2EigenPadding(padding_type));
  }
};

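// Conceptually (a sketch of what the Eigen expressions above compute, not
// code used here), each output element of an NDHWC cuboid pool reduces one
// window of the input:
//
//   out(n, op, or, oc, d) =
//       reduce over (p, r, c) in the window anchored at
//       (op * stride_p, or * stride_r, oc * stride_c) of in(n, p, r, c, d)
//
// where reduce is max for CuboidMaxPooling and mean for CuboidAvgPooling;
// padded positions never win the max, and the average is taken over the
// in-bounds elements.
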
template <typename Device, typename T, PoolingType Type>
class Pooling3DOp : public UnaryOp<T> {
 public:
  explicit Pooling3DOp(OpKernelConstruction* context) : UnaryOp<T>(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    if (context->device_type() == DEVICE_CPU) {
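      // Note: for 5D tensors, FormatFromString() maps the string "NDHWC" to
      // the FORMAT_NHWC enum value (see tensor_format.h), so this comparison
      // accepts NDHWC inputs.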
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument("Default Pooling3DOp only supports NDHWC ",
                                  "on device type ",
                                  DeviceTypeString(context->device_type())));
    }
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 5,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window stride field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES(context,
                (GetTensorDim(ksize_, data_format_, 'N') == 1 &&
                 GetTensorDim(stride_, data_format_, 'N') == 1),
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
    OP_REQUIRES(context,
                (GetTensorDim(ksize_, data_format_, 'C') == 1 &&
                 GetTensorDim(stride_, data_format_, 'C') == 1),
                errors::Unimplemented(
                    "Pooling is not yet supported on the depth dimension."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);

    OP_REQUIRES(context, tensor_in.dims() == 5,
                errors::InvalidArgument("tensor_in must be 5-dimensional"));
    const int64 depth = GetTensorDim(tensor_in, data_format_, 'C');
    const int64 in_batch = GetTensorDim(tensor_in, data_format_, 'N');

    // Dimension order for these arrays is: x, y, z.
    std::array<int64, 3> input_size{
        {GetTensorDim(tensor_in, data_format_, '2'),
         GetTensorDim(tensor_in, data_format_, '1'),
         GetTensorDim(tensor_in, data_format_, '0')}};
    std::array<int64, 3> window{{GetTensorDim(ksize_, data_format_, '2'),
                                 GetTensorDim(ksize_, data_format_, '1'),
                                 GetTensorDim(ksize_, data_format_, '0')}};
    std::array<int64, 3> stride{{GetTensorDim(stride_, data_format_, '2'),
                                 GetTensorDim(stride_, data_format_, '1'),
                                 GetTensorDim(stride_, data_format_, '0')}};
    std::array<int64, 3> padding, out;

    OP_REQUIRES_OK(context, Get3dOutputSize(input_size, window, stride,
                                            padding_, &out, &padding));

    TensorShape out_shape = ShapeFromFormat(data_format_, in_batch,
                                            {{out[2], out[1], out[0]}}, depth);
    Tensor* output;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
    LaunchPoolingOp<Device, T, Type>::launch(context, tensor_in, window, stride,
                                             padding, data_format_, padding_,
                                             output);
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

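// Worked example of the Get3dOutputSize arithmetic used above (standard
// windowed-output formulas; the numbers are illustrative): for one spatial
// dimension with input extent 7, window 3, stride 2,
//
//   VALID: out = (7 - 3) / 2 + 1 = 3, padding = 0
//   SAME:  out = ceil(7 / 2) = 4, total padding = (4 - 1) * 2 + 3 - 7 = 2
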
template <typename T>
struct LaunchMaxPooling3dGradOp<CPUDevice, T> {
  static void launch(OpKernelContext* context, const Tensor& tensor_in,
                     const Tensor& tensor_out, const Tensor& out_backprop,
                     const std::array<int64, 3>& window,
                     const std::array<int64, 3>& stride,
                     const std::array<int64, 3>& out,
                     const std::array<int64, 3>& padding,
                     TensorFormat data_format, Tensor* output) {
    output->flat<T>().setZero();
    for (int64 p = 0; p < out_backprop.dim_size(3); ++p) {
      // Calculate broadcast size for planes/rows/cols. For SAME padding,
      // current index could be in the padding area, and
      //   p * stride_planes + window_planes
      // could be beyond the input tensor's boundary. In such cases, change
      // the starting index and reduce the broadcast size.
      //
      // The same procedure is repeated for every spatial dimension in the
      // nested loops below.
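      //
      // A concrete instance (illustrative numbers): with input extent 7,
      // window 3, stride 2, and SAME padding 1, output index 0 covers the raw
      // input range [-1, 2); GetBroadcastSize clips this to start 0, size 2,
      // so the gradient is broadcast over 2 input cells instead of 3.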
      int pindex, psize;
      std::array<int64, 3> input_size{{tensor_in.dim_size(3),
                                       tensor_in.dim_size(2),
                                       tensor_in.dim_size(1)}};
      OP_REQUIRES_OK(context,
                     GetBroadcastSize(p, input_size[0], window[0], stride[0],
                                      padding[0], &pindex, &psize));
      for (int64 r = 0; r < out_backprop.dim_size(2); ++r) {
        int rindex, rsize;
        OP_REQUIRES_OK(context,
                       GetBroadcastSize(r, input_size[1], window[1], stride[1],
                                        padding[1], &rindex, &rsize));
        for (int64 c = 0; c < out_backprop.dim_size(1); ++c) {
          int cindex, csize;
          OP_REQUIRES_OK(
              context, GetBroadcastSize(c, input_size[2], window[2], stride[2],
                                        padding[2], &cindex, &csize));
          TensorSlice src{{0, -1}, {c, 1}, {r, 1}, {p, 1}, {0, -1}};
          TensorSlice dst{{0, -1},
                          {cindex, csize},
                          {rindex, rsize},
                          {pindex, psize},
                          {0, -1}};
          Eigen::DSizes<Eigen::DenseIndex, 5> src_indices;
          Eigen::DSizes<Eigen::DenseIndex, 5> src_sizes;
          Eigen::DSizes<Eigen::DenseIndex, 5> dst_indices;
          Eigen::DSizes<Eigen::DenseIndex, 5> dst_sizes;
          src.FillIndicesAndSizes<5>(out_backprop.shape(), &src_indices,
                                     &src_sizes);
          dst.FillIndicesAndSizes<5>(tensor_in.shape(), &dst_indices,
                                     &dst_sizes);

#if !defined(EIGEN_HAS_INDEX_LIST)
          Eigen::array<int, 5> bcast = {1, csize, rsize, psize, 1};
#else
          Eigen::IndexList<Eigen::type2index<1>, int, int, int,
                           Eigen::type2index<1>>
              bcast;
          bcast.set(1, csize);
          bcast.set(2, rsize);
          bcast.set(3, psize);
#endif

          // Slice from tensor_in.
          Eigen::Tensor<T, 5, Eigen::RowMajor> tensor_in_slice(dst_sizes);
          tensor_in_slice.device(context->eigen_cpu_device()) =
              tensor_in.tensor<T, 5>().slice(dst_indices, dst_sizes);

          // Slice from tensor_out.
          Eigen::Tensor<T, 5, Eigen::RowMajor> tensor_out_slice(src_sizes);
          tensor_out_slice.device(context->eigen_cpu_device()) =
              tensor_out.tensor<T, 5>().slice(src_indices, src_sizes);

          // Backprop slice.
          Eigen::Tensor<T, 5, Eigen::RowMajor> out_backprop_slice(src_sizes);
          out_backprop_slice.device(context->eigen_cpu_device()) =
              out_backprop.tensor<T, 5>().slice(src_indices, src_sizes);
          // The true backprop slice: where an input element matches the
          // pooled maximum (to within a 1e-5 tolerance, since exact
          // floating-point equality is brittle), pass the broadcast gradient
          // through; everywhere else propagate 0.
          Eigen::Tensor<T, 5, Eigen::RowMajor> select_slice(dst_sizes);
          Eigen::Tensor<T, 5, Eigen::RowMajor> mat0(dst_sizes);
          mat0.setZero();
          select_slice =
              ((tensor_in_slice - tensor_out_slice.broadcast(bcast)).abs() <
               tensor_in_slice.constant(1e-5))
                  .select(out_backprop_slice.broadcast(bcast), mat0);

          output->tensor<T, 5>()
              .slice(dst_indices, dst_sizes)
              .device(context->eigen_cpu_device()) += select_slice;
        }
      }
    }
  }
};

template <class Device, class T>
class MaxPooling3dGradOp : public OpKernel {
 public:
  explicit MaxPooling3dGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    if (context->device_type() == DEVICE_CPU) {
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Default MaxPooling3dGradOp only supports NDHWC ",
              "on device type ", DeviceTypeString(context->device_type())));
    }
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 5,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window stride field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES(context,
                (GetTensorDim(ksize_, data_format_, 'N') == 1 &&
                 GetTensorDim(stride_, data_format_, 'N') == 1),
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
    OP_REQUIRES(context,
                (GetTensorDim(ksize_, data_format_, 'C') == 1 &&
                 GetTensorDim(stride_, data_format_, 'C') == 1),
                errors::Unimplemented(
                    "Pooling is not yet supported on the depth dimension."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);
    const Tensor& tensor_out = context->input(1);
    const Tensor& out_backprop = context->input(2);
    OP_REQUIRES(context, tensor_in.dims() == 5,
                errors::InvalidArgument("tensor_in must be 5-dimensional"));
    OP_REQUIRES(context, tensor_out.dims() == 5,
                errors::InvalidArgument("tensor_out must be 5-dimensional"));
    OP_REQUIRES(context, out_backprop.dims() == 5,
                errors::InvalidArgument("out_backprop must be 5-dimensional"));

    const TensorShape& output_shape = tensor_in.shape();
    Tensor* input_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &input_backprop));
    std::array<int64, 3> input_size{
        {GetTensorDim(output_shape, data_format_, '2'),
         GetTensorDim(output_shape, data_format_, '1'),
         GetTensorDim(output_shape, data_format_, '0')}};
    std::array<int64, 3> window{{GetTensorDim(ksize_, data_format_, '2'),
                                 GetTensorDim(ksize_, data_format_, '1'),
                                 GetTensorDim(ksize_, data_format_, '0')}};
    std::array<int64, 3> stride{{GetTensorDim(stride_, data_format_, '2'),
                                 GetTensorDim(stride_, data_format_, '1'),
                                 GetTensorDim(stride_, data_format_, '0')}};
    std::array<int64, 3> out, padding;

    OP_REQUIRES_OK(context, Get3dOutputSize(input_size, window, stride,
                                            padding_, &out, &padding));
    LaunchMaxPooling3dGradOp<Device, T>::launch(
        context, tensor_in, tensor_out, out_backprop, window, stride, out,
        padding, data_format_, input_backprop);
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

template <typename T>
struct LaunchAvgPooling3dGradOp<CPUDevice, T> {
  static void launch(OpKernelContext* context,
                     const TensorShape& tensor_in_shape,
                     const Tensor& out_backprop,
                     const std::array<int64, 3>& window,
                     const std::array<int64, 3>& stride,
                     const std::array<int64, 3>& output_shape,
                     const std::array<int64, 3>& padding,
                     TensorFormat data_format, Tensor* output) {
    output->flat<T>().setZero();
    std::array<int64, 3> input_size = {{tensor_in_shape.dim_size(3),
                                        tensor_in_shape.dim_size(2),
                                        tensor_in_shape.dim_size(1)}};
    for (int64 p = 0; p < out_backprop.dim_size(3); ++p) {
      // Calculate broadcast size for planes/rows/cols. For SAME padding,
      // current index could be in the padding area, and
      //   p * stride_planes + window_planes
      // could be beyond the input tensor's boundary. In such cases, change
      // the starting index and reduce the broadcast size.
      //
      // The same procedure is repeated for every spatial dimension in the
      // nested loops below.
      int pindex, psize;
      OP_REQUIRES_OK(context,
                     GetBroadcastSize(p, input_size[0], window[0], stride[0],
                                      padding[0], &pindex, &psize));
      for (int64 r = 0; r < out_backprop.dim_size(2); ++r) {
        int rindex, rsize;
        OP_REQUIRES_OK(context,
                       GetBroadcastSize(r, input_size[1], window[1], stride[1],
                                        padding[1], &rindex, &rsize));
        for (int64 c = 0; c < out_backprop.dim_size(1); ++c) {
          int cindex, csize;
          OP_REQUIRES_OK(
              context, GetBroadcastSize(c, input_size[2], window[2], stride[2],
                                        padding[2], &cindex, &csize));
          TensorSlice src{{0, -1}, {c, 1}, {r, 1}, {p, 1}, {0, -1}};
          TensorSlice dst{{0, -1},
                          {cindex, csize},
                          {rindex, rsize},
                          {pindex, psize},
                          {0, -1}};
          Eigen::DSizes<Eigen::DenseIndex, 5> src_indices;
          Eigen::DSizes<Eigen::DenseIndex, 5> src_sizes;
          Eigen::DSizes<Eigen::DenseIndex, 5> dst_indices;
          Eigen::DSizes<Eigen::DenseIndex, 5> dst_sizes;
          src.FillIndicesAndSizes<5>(out_backprop.shape(), &src_indices,
                                     &src_sizes);
          dst.FillIndicesAndSizes<5>(tensor_in_shape, &dst_indices, &dst_sizes);
#if !defined(EIGEN_HAS_INDEX_LIST)
          Eigen::array<int, 5> bcast = {1, csize, rsize, psize, 1};
#else
          Eigen::IndexList<Eigen::type2index<1>, int, int, int,
                           Eigen::type2index<1>>
              bcast;
          bcast.set(1, csize);
          bcast.set(2, rsize);
          bcast.set(3, psize);
#endif
          Eigen::Tensor<T, 5, Eigen::RowMajor> slices(src_sizes);
          slices.device(context->eigen_cpu_device()) =
              out_backprop.tensor<T, 5>().slice(src_indices, src_sizes);
          // Divide by the size of the actual patch (psize * rsize * csize).
          float divide_size = rsize * csize * psize * 1.0f;
          slices *= slices.constant(1.0f / divide_size);

          output->tensor<T, 5>()
              .slice(dst_indices, dst_sizes)
              .device(context->eigen_cpu_device()) += slices.broadcast(bcast);
        }
      }
    }
  }
};

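// Numeric sanity check for the division above (illustrative): a gradient
// value pooled over a full 2x2x2 window is spread as grad / 8 to each of the
// 8 input cells; at a clipped SAME-padding border where only, say, a 1x2x2
// sub-patch is in bounds, the divisor shrinks to 4, mirroring the forward
// average over in-bounds elements.
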
template <class Device, class T>
class AvgPooling3dGradOp : public OpKernel {
 public:
  explicit AvgPooling3dGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    if (context->device_type() == DEVICE_CPU) {
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Default AvgPooling3dGradOp only supports NDHWC ",
              "on device type ", DeviceTypeString(context->device_type())));
    }
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 5,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window stride field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES(context,
                (GetTensorDim(ksize_, data_format_, 'N') == 1 &&
                 GetTensorDim(stride_, data_format_, 'N') == 1),
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
    OP_REQUIRES(context,
                (GetTensorDim(ksize_, data_format_, 'C') == 1 &&
                 GetTensorDim(stride_, data_format_, 'C') == 1),
                errors::Unimplemented(
                    "Pooling is not yet supported on the depth dimension."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in_shape = context->input(0);
    const Tensor& out_backprop = context->input(1);
    OP_REQUIRES(
        context,
        tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 5,
        errors::InvalidArgument("tensor_in must be 1-dimensional and "
                                "contain 5 elements"));
    OP_REQUIRES(context, out_backprop.dims() == 5,
                errors::InvalidArgument("out_backprop must be 5-dimensional"));

    TensorShape output_shape;
    auto shape_vec = tensor_in_shape.vec<int32>();
    for (int64 i = 0; i < tensor_in_shape.NumElements(); ++i) {
      output_shape.AddDim(shape_vec(i));
    }

    Tensor* output;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

    // Dimension order for these arrays is x, y, z.
    std::array<int64, 3> input_size{
        {GetTensorDim(output_shape, data_format_, '2'),
         GetTensorDim(output_shape, data_format_, '1'),
         GetTensorDim(output_shape, data_format_, '0')}};
    std::array<int64, 3> window{{GetTensorDim(ksize_, data_format_, '2'),
                                 GetTensorDim(ksize_, data_format_, '1'),
                                 GetTensorDim(ksize_, data_format_, '0')}};
    std::array<int64, 3> stride{{GetTensorDim(stride_, data_format_, '2'),
                                 GetTensorDim(stride_, data_format_, '1'),
                                 GetTensorDim(stride_, data_format_, '0')}};
    std::array<int64, 3> padding, out;

    OP_REQUIRES_OK(context, Get3dOutputSize(input_size, window, stride,
                                            padding_, &out, &padding));

    LaunchAvgPooling3dGradOp<Device, T>::launch(
        context, output_shape, out_backprop, window, stride, out, padding,
        data_format_, output);
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

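// Unlike MaxPool3DGrad, AvgPool3DGrad never needs the original input values,
// only its shape, which arrives as a 1-D int32 tensor in host memory. A
// sketch of the expected inputs (shapes illustrative): for a forward input of
// shape [1, 8, 8, 8, 3] pooled with a 2x2x2/stride-2 VALID window, input 0 is
// the vector {1, 8, 8, 8, 3} and input 1 is the backpropagated gradient of
// shape [1, 4, 4, 4, 3].
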
template <typename T>
struct LaunchMaxPooling3dGradGradOp<CPUDevice, T> {
  static void launch(OpKernelContext* context, const Pool3dParameters& params,
                     const Tensor& tensor_in, const Tensor& tensor_out,
                     const Tensor& tensor_top_diff,
                     Tensor* tensor_bottom_diff) {
    OP_REQUIRES(
        context, params.data_format == FORMAT_NHWC,
        errors::InvalidArgument("Default MaxPooling3dGradGradOp only supports ",
                                "NDHWC on CPU device type"));

    typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
        ConstEigenMatrixMap;
    typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
        EigenMatrixMap;

    ConstEigenMatrixMap in_mat(tensor_in.flat<T>().data(), params.depth,
                               params.tensor_in_planes * params.tensor_in_cols *
                                   params.tensor_in_rows *
                                   params.tensor_in_batch);
    ConstEigenMatrixMap out_mat(tensor_out.flat<T>().data(), params.depth,
                                params.out_plane * params.out_width *
                                    params.out_height * params.tensor_in_batch);
    ConstEigenMatrixMap top_diff_mat(
        tensor_top_diff.flat<T>().data(), params.depth,
        params.tensor_in_planes * params.tensor_in_cols *
            params.tensor_in_rows * params.tensor_in_batch);
    EigenMatrixMap bottom_diff_mat(
        tensor_bottom_diff->flat<T>().data(), params.depth,
        params.out_plane * params.out_width * params.out_height *
            params.tensor_in_batch);

    const DeviceBase::CpuWorkerThreads& worker_threads =
        *(context->device()->tensorflow_cpu_worker_threads());

    auto shard = [&params, &in_mat, &out_mat, &top_diff_mat, &bottom_diff_mat](
                     int64 start, int64 limit) {
      const int32 depth = params.depth;
      const int32 in_planes = params.tensor_in_planes;
      const int32 in_rows = params.tensor_in_rows;
      const int32 in_cols = params.tensor_in_cols;
      const int32 pad_planes = params.pad_planes;
      const int32 pad_rows = params.pad_rows;
      const int32 pad_cols = params.pad_cols;
      const int32 window_planes = params.window_planes;
      const int32 window_rows = params.window_rows;
      const int32 window_cols = params.window_cols;
      const int32 plane_stride = params.plane_stride;
      const int32 row_stride = params.row_stride;
      const int32 col_stride = params.col_stride;
      const int32 out_plane = params.out_plane;
      const int32 out_height = params.out_height;
      const int32 out_width = params.out_width;

      {
        // Initialize this shard's slice of the output backprop tensor to 0.
        const int32 output_image_size =
            out_plane * out_height * out_width * params.depth;
        EigenMatrixMap bottom_diff_shard(
            bottom_diff_mat.data() + start * output_image_size, 1,
            (limit - start) * output_image_size);
        bottom_diff_shard.setZero();
      }

      for (int b = start; b < limit; ++b) {
        for (int pp = 0; pp < out_plane; ++pp) {
          for (int ph = 0; ph < out_height; ++ph) {
            for (int pw = 0; pw < out_width; ++pw) {
              // (p_start, p_end) x (h_start, h_end) x (w_start, w_end) is the
              // input window this output element was pooled from.
              int p_start = pp * plane_stride - pad_planes;
              const int p_end = std::min(p_start + window_planes, in_planes);
              int h_start = ph * row_stride - pad_rows;
              const int h_end = std::min(h_start + window_rows, in_rows);
              int w_start = pw * col_stride - pad_cols;
              const int w_end = std::min(w_start + window_cols, in_cols);
              p_start = std::max(p_start, 0);
              h_start = std::max(h_start, 0);
              w_start = std::max(w_start, 0);
              const int out_index =
                  ((b * out_plane + pp) * out_height + ph) * out_width + pw;
              // Find value corresponding to the input maximum in top_diff.
              for (int d = 0; d < depth; ++d) {
                const T& output_ref = out_mat.coeffRef(d, out_index);
                bool should_stop = false;
                for (int p = p_start; p < p_end && !should_stop; ++p) {
                  for (int h = h_start; h < h_end && !should_stop; ++h) {
                    for (int w = w_start; w < w_end && !should_stop; ++w) {
                      const int in_index =
                          ((b * in_planes + p) * in_rows + h) * in_cols + w;
                      const T& input_ref = in_mat.coeffRef(d, in_index);
                      if (output_ref == input_ref) {
                        T& bottom_diff_ref =
                            bottom_diff_mat.coeffRef(d, out_index);
                        bottom_diff_ref = top_diff_mat.coeffRef(d, in_index);
                        should_stop = true;
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    };
    const int64 shard_cost =
        params.out_plane * params.out_height * params.out_width * params.depth *
        params.window_planes * params.window_rows * params.window_cols;
    Shard(worker_threads.num_threads, worker_threads.workers,
          params.tensor_in_batch, shard_cost, shard);
  }
};

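// A note on the Shard() call above: it splits the batch range
// [0, tensor_in_batch) across the CPU worker threads and invokes the lambda
// with disjoint [start, limit) pieces, using shard_cost as a rough per-batch
// work estimate. Illustratively, a batch of 8 on 4 threads might yield calls
// shard(0, 2), shard(2, 4), shard(4, 6), shard(6, 8).
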
template <class Device, class T>
class MaxPooling3dGradGradOp : public OpKernel {
 public:
  explicit MaxPooling3dGradGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 5,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
    const int32 ksize_c = GetTensorDim(ksize_, data_format_, 'C');
    const int32 stride_c = GetTensorDim(stride_, data_format_, 'C');
    OP_REQUIRES(context, ksize_c == 1 && stride_c == 1,
                errors::Unimplemented("MaxPooling3dGradGrad is not yet "
                                      "supported on the depth dimension."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);
    const Tensor& tensor_out = context->input(1);
    const Tensor& out_grad_backprop = context->input(2);

    // For maxpooling3d, tensor_in should have 5 dimensions.
    OP_REQUIRES(context, tensor_in.dims() == 5,
                errors::InvalidArgument("tensor_in must be 5-dimensional"));
    OP_REQUIRES(context, tensor_out.dims() == 5,
                errors::InvalidArgument("tensor_out must be 5-dimensional"));
    // For maxpooling3d, out_grad_backprop should have 5 dimensions.
    OP_REQUIRES(
        context, out_grad_backprop.dims() == 5,
        errors::InvalidArgument("out_grad_backprop must be 5-dimensional"));

    Pool3dParameters params{context,  ksize_,       stride_,
                            padding_, data_format_, tensor_in.shape()};

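    // Forward the out_grad_backprop buffer (input 2) into output 0 when it
    // can be reused in place; otherwise allocate a fresh tensor of
    // tensor_out's shape.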
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {2}, 0, tensor_out.shape(), &output));

    LaunchMaxPooling3dGradGradOp<Device, T>::launch(
        context, params, tensor_in, tensor_out, out_grad_backprop, output);
  }

 private:
  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

#define REGISTER_KERNELS(D, T)                                             \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("MaxPool3D").Device(DEVICE_##D).TypeConstraint<T>("T"),         \
      Pooling3DOp<D##Device, T, MAX>);                                     \
  REGISTER_KERNEL_BUILDER(Name("MaxPool3DGrad")                            \
                              .Device(DEVICE_##D)                          \
                              .TypeConstraint<T>("T")                      \
                              .TypeConstraint<T>("TInput"),                \
                          MaxPooling3dGradOp<D##Device, T>);               \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("MaxPool3DGradGrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
      MaxPooling3dGradGradOp<D##Device, T>);                               \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("AvgPool3D").Device(DEVICE_##D).TypeConstraint<T>("T"),         \
      Pooling3DOp<D##Device, T, AVG>);                                     \
  REGISTER_KERNEL_BUILDER(Name("AvgPool3DGrad")                            \
                              .Device(DEVICE_##D)                          \
                              .TypeConstraint<T>("T")                      \
                              .HostMemory("orig_input_shape"),             \
                          AvgPooling3dGradOp<D##Device, T>);

#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T)
TF_CALL_float(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
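
// Registering another element type on CPU (assuming the op definitions and
// the Eigen pooling code support it) would follow the same pattern, e.g.:
//
//   #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T)
//   TF_CALL_double(REGISTER_CPU_KERNELS);
//   #undef REGISTER_CPU_KERNELS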

#if GOOGLE_CUDA

template <typename T>
struct LaunchPoolingOp<GPUDevice, T, AVG> {
  static void launch(OpKernelContext* context, const Tensor& tensor_in,
                     const std::array<int64, 3>& window,
                     const std::array<int64, 3>& stride,
                     const std::array<int64, 3>& padding,
                     TensorFormat data_format, Padding padding_type,
                     Tensor* output) {
    DnnPooling3dOp<T>::Compute(
        context, perftools::gputools::dnn::PoolingMode::kAverage, window,
        stride, padding, data_format, tensor_in, output);
  }
};

template <typename T>
struct LaunchPoolingOp<GPUDevice, T, MAX> {
  static void launch(OpKernelContext* context, const Tensor& tensor_in,
                     const std::array<int64, 3>& window,
                     const std::array<int64, 3>& stride,
                     const std::array<int64, 3>& padding,
                     TensorFormat data_format, Padding padding_type,
                     Tensor* output) {
    DnnPooling3dOp<T>::Compute(
        context, perftools::gputools::dnn::PoolingMode::kMaximum, window,
        stride, padding, data_format, tensor_in, output);
  }
};

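// Both GPU specializations above delegate to DnnPooling3dOp (declared in
// cudnn_pooling_gpu.h), which hands the pooling off to cuDNN through
// StreamExecutor; only the PoolingMode differs between AVG and MAX.
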
template <typename T>
struct LaunchMaxPooling3dGradOp<GPUDevice, T> {
  static void launch(OpKernelContext* context, const Tensor& tensor_in,
                     const Tensor& tensor_out, const Tensor& out_backprop,
                     const std::array<int64, 3>& window,
                     const std::array<int64, 3>& stride,
                     const std::array<int64, 3>& out,
                     const std::array<int64, 3>& padding,
                     TensorFormat data_format, Tensor* input_backprop) {
    const TensorShape output_shape = tensor_in.shape();
    DnnPooling3dGradOp<T>::Compute(
        context, perftools::gputools::dnn::PoolingMode::kMaximum, window,
        stride, padding, out, data_format, out_backprop, output_shape,
        &tensor_in, &tensor_out, input_backprop);
  }
};

template <typename T>
struct LaunchAvgPooling3dGradOp<GPUDevice, T> {
  static void launch(OpKernelContext* context,
                     const TensorShape& tensor_in_shape,
                     const Tensor& out_backprop,
                     const std::array<int64, 3>& window,
                     const std::array<int64, 3>& stride,
                     const std::array<int64, 3>& out,
                     const std::array<int64, 3>& padding,
                     TensorFormat data_format, Tensor* output) {
    DnnPooling3dGradOp<T>::Compute(
        context, perftools::gputools::dnn::PoolingMode::kAverage, window,
        stride, padding, out, data_format, out_backprop, tensor_in_shape,
        nullptr, nullptr, output);
  }
};

template <typename T>
struct LaunchMaxPooling3dGradGradOp<GPUDevice, T> {
  static void launch(OpKernelContext* context, const Pool3dParameters& params,
                     const Tensor& tensor_in, const Tensor& tensor_out,
                     const Tensor& tensor_top_diff,
                     Tensor* tensor_bottom_diff) {
    bool status = functor::MaxPool3dGradBackward<T>()(
        params.data_format, tensor_in.flat<T>().data(),
        tensor_out.flat<T>().data(), params.tensor_in_batch, params.out_plane,
        params.out_height, params.out_width, params.depth,
        params.tensor_in_planes, params.tensor_in_rows, params.tensor_in_cols,
        params.window_planes, params.window_rows, params.window_cols,
        params.plane_stride, params.row_stride, params.col_stride,
        params.pad_planes, params.pad_rows, params.pad_cols,
        tensor_top_diff.flat<T>().data(), tensor_bottom_diff->flat<T>().data(),
        context->eigen_gpu_device());
    if (!status) {
      context->SetStatus(
          errors::Internal("Failed launching MaxPool3dGradBackward"));
    }
  }
};

#define REGISTER_GPU_KERNELS(T) REGISTER_KERNELS(GPU, T)
TF_CALL_float(REGISTER_GPU_KERNELS) TF_CALL_half(REGISTER_GPU_KERNELS)
#undef REGISTER_GPU_KERNELS

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T)
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS)
#undef REGISTER_SYCL_KERNELS
#endif  // TENSORFLOW_USE_SYCL

#undef REGISTER_KERNELS

}  // namespace tensorflow