/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_POOLING_OPS_COMMON_H_
#define TENSORFLOW_KERNELS_POOLING_OPS_COMMON_H_

#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/avgpooling_op.h"
#include "tensorflow/core/kernels/maxpooling_op.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/work_sharder.h"

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/maxpooling_op_gpu.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

// A helper class to manage sizes and shapes for pooling operations.
struct PoolParameters {
  // Updates context->status if there is an invalid input.
  PoolParameters(OpKernelContext* context, const std::vector<int32>& ksize,
                 const std::vector<int32>& stride, Padding padding,
                 TensorFormat data_format, const TensorShape& tensor_in_shape);

  // Returns the shape of the output for "forward" pooling operations.
  TensorShape forward_output_shape();

  int depth;

  int tensor_in_cols;
  int tensor_in_rows;
  int tensor_in_batch;

  int window_rows;
  int window_cols;
  int depth_window;

  int row_stride;
  int col_stride;
  int depth_stride;

  int64 out_height;
  int64 out_width;
  int out_depth;

  int64 pad_rows;
  int64 pad_cols;
  int pad_depth;

  TensorFormat data_format;
};
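
// For orientation: the out_height/out_width/pad_* fields above are filled in
// by the constructor (see pooling_ops_common.cc). A rough sketch of the
// standard windowed-output-size computation, with input extent `in`, window
// `k` and stride `s` (illustrative only; the constructor is authoritative):
//   VALID: out = ceil((in - k + 1) / s), no padding
//   SAME:  out = ceil(in / s), total padding = (out - 1) * s + k - in
//          (clamped at 0)
// e.g. in = 7, k = 3, s = 2 gives out = 3 (VALID) and out = 4 (SAME) with
// total padding 2, i.e. 1 element of padding before the first input element.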

// An implementation of MaxPooling (forward).
// TODO(yongtang): Remove MaxPoolingOp and use MaxPoolingV2Op instead;
//     QuantizedMaxPoolingOp depends on MaxPoolingOp, so keep it intact for now.
template <typename Device, typename T>
class MaxPoolingOp : public OpKernel {
 public:
  explicit MaxPoolingOp(OpKernelConstruction* context) : OpKernel(context) {
    string data_format;
    auto status = context->GetAttr("data_format", &data_format);
    if (status.ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument("Default MaxPoolingOp only supports NHWC ",
                                  "on device type ",
                                  DeviceTypeString(context->device_type())));
    } else {
      data_format_ = FORMAT_NHWC;
    }
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 4,
                errors::InvalidArgument("Sliding window stride field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);
    PoolParameters params{context,  ksize_,      stride_,
                          padding_, FORMAT_NHWC, tensor_in.shape()};
    if (!context->status().ok()) {
      return;
    }

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                0, params.forward_output_shape(), &output));

    if (params.depth_window > 1) {
      // Validate spec against the current implementation.  A
      // relaxation of these requirements would be ideal.
      OP_REQUIRES(context, params.depth % params.depth_window == 0,
                  errors::Unimplemented(
                      "Depthwise max pooling requires "
                      "the depth window to evenly divide the input depth."));
      OP_REQUIRES(
          context, params.depth_window == params.depth_stride,
          errors::Unimplemented("Depthwise max pooling requires "
                                "the depth window to equal the depth stride."));

      DepthwiseMaxPool(context, output, tensor_in, params);
    } else {
      SpatialMaxPool(context, output, tensor_in, params, padding_);
    }
  }

 private:
  // Single-threaded implementation of DepthwiseMaxPool, which does not
  // handle all of the same options as SpatialMaxPool (it makes strict
  // assumptions about padding and stride).
  //
  // TODO(vrv): implement a more general depthwise-max pool that works
  // on GPU as well.
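  //
  // Illustrative example (hypothetical sizes, not extra behavior): with NHWC
  // input, depth = 6 and depth_window = depth_stride = 3, the flat input is
  // viewed as a 3 x (NumElements / 3) matrix, so each column holds one group
  // of 3 consecutive channels; the column-wise max then collapses channels
  // {0,1,2} into output channel 0 and {3,4,5} into output channel 1, giving
  // out_depth = 2.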
  void DepthwiseMaxPool(OpKernelContext* context, Tensor* output,
                        const Tensor& tensor_in, const PoolParameters& params) {
    Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
        in_by_pool(tensor_in.flat<T>().data(), params.depth_window,
                   tensor_in.NumElements() / params.depth_window);
    Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> out_by_pool(
        output->flat<T>().data(), 1, output->NumElements());
    out_by_pool = in_by_pool.colwise().maxCoeff();
  }

  void SpatialMaxPool(OpKernelContext* context, Tensor* output,
                      const Tensor& tensor_in, const PoolParameters& params,
                      const Padding& padding) {
    // On GPU, use Eigen's Spatial Max Pooling.  On CPU, use an
    // EigenMatrix version that is currently faster than Eigen's
    // Spatial MaxPooling implementation.
    //
    // TODO(vrv): Remove this once we no longer need it.
    if (std::is_same<Device, GPUDevice>::value) {
      Eigen::PaddingType pt = BrainPadding2EigenPadding(padding);
      functor::SpatialMaxPooling<Device, T>()(
          context->eigen_device<Device>(), output->tensor<T, 4>(),
          tensor_in.tensor<T, 4>(), params.window_rows, params.window_cols,
          params.row_stride, params.col_stride, pt);
    } else {
      typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
          ConstEigenMatrixMap;
      typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
          EigenMatrixMap;

      ConstEigenMatrixMap in_mat(tensor_in.flat<T>().data(), params.depth,
                                 params.tensor_in_cols * params.tensor_in_rows *
                                     params.tensor_in_batch);
      EigenMatrixMap out_mat(
          output->flat<T>().data(), params.depth,
          params.out_width * params.out_height * params.tensor_in_batch);

      const DeviceBase::CpuWorkerThreads& worker_threads =
          *(context->device()->tensorflow_cpu_worker_threads());

      // The following code does two things:
      // 1. Flattens the input and output tensors into two-dimensional arrays:
      //    tensor_in_as_matrix:
      //      depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch)
      //    output_as_matrix:
      //      depth by (out_width * out_height * tensor_in_batch)
      //
      // 2. Walks through the set of columns in the flattened
      //    tensor_in_as_matrix, and updates the corresponding column(s) in
      //    output_as_matrix with the max value.
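      //
      // Worked example of the projection below (illustrative numbers only):
      // with in_rows = 5, window_rows = 3, row_stride = 2 and SAME padding
      // (out_height = 3, pad_rows = 1), input row h = 2 gives hpad = 3,
      // h_start = (3 - 3) / 2 + 1 = 1 and h_end = min(3 / 2 + 1, 3) = 2, so
      // that row contributes only to output row 1, the single pooling window
      // whose receptive field [1, 3] contains input row 2.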
      auto shard = [&params, &in_mat, &out_mat](int64 start, int64 limit) {
        const int32 in_rows = params.tensor_in_rows;
        const int32 in_cols = params.tensor_in_cols;
        const int32 pad_rows = params.pad_rows;
        const int32 pad_cols = params.pad_cols;
        const int32 window_rows = params.window_rows;
        const int32 window_cols = params.window_cols;
        const int32 row_stride = params.row_stride;
        const int32 col_stride = params.col_stride;
        const int32 out_height = params.out_height;
        const int32 out_width = params.out_width;

        {
          // Initializes the output tensor with MIN<T>.
          const int32 output_image_size = out_height * out_width * params.depth;
          EigenMatrixMap out_shard(out_mat.data() + start * output_image_size,
                                   1, (limit - start) * output_image_size);
          out_shard.setConstant(Eigen::NumTraits<T>::lowest());
        }

        for (int32 b = start; b < limit; ++b) {
          const int32 out_offset_batch = b * out_height;
          for (int32 h = 0; h < in_rows; ++h) {
            for (int32 w = 0; w < in_cols; ++w) {
              // (h_start, h_end) * (w_start, w_end) is the range that the input
              // vector projects to.
              const int32 hpad = h + pad_rows;
              const int32 wpad = w + pad_cols;
              const int32 h_start = (hpad < window_rows)
                                        ? 0
                                        : (hpad - window_rows) / row_stride + 1;
              const int32 h_end = std::min(hpad / row_stride + 1, out_height);
              const int32 w_start = (wpad < window_cols)
                                        ? 0
                                        : (wpad - window_cols) / col_stride + 1;
              const int32 w_end = std::min(wpad / col_stride + 1, out_width);
              // compute elementwise max
              const int32 in_offset = (b * in_rows + h) * in_cols + w;
              for (int32 ph = h_start; ph < h_end; ++ph) {
                const int32 out_offset_base =
                    (out_offset_batch + ph) * out_width;
                for (int32 pw = w_start; pw < w_end; ++pw) {
                  const int32 out_offset = out_offset_base + pw;
                  out_mat.col(out_offset) =
                      out_mat.col(out_offset).cwiseMax(in_mat.col(in_offset));
                }
              }
            }
          }
        }
      };

      // TODO(andydavis) Consider sharding across batch x rows x cols.
      // TODO(andydavis) Consider a higher resolution shard cost model.
      const int64 shard_cost =
          params.tensor_in_rows * params.tensor_in_cols * params.depth;
      Shard(worker_threads.num_threads, worker_threads.workers,
            params.tensor_in_batch, shard_cost, shard);
    }
  }

  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};
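
// A registration sketch for the kernel above (the real registrations live in
// the corresponding .cc files, e.g. maxpooling_op.cc; shown here only for
// orientation):
//
//   REGISTER_KERNEL_BUILDER(
//       Name("MaxPool").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//       MaxPoolingOp<Eigen::ThreadPoolDevice, float>);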

template <typename Device>
struct LaunchMaxPoolingNoMask_NCHW_VECT_C;

#ifdef GOOGLE_CUDA
template <>
struct LaunchMaxPoolingNoMask_NCHW_VECT_C<Eigen::GpuDevice> {
  static void launch(OpKernelContext* context, const PoolParameters& params,
                     const Tensor& input, Tensor* output) {
    bool status = functor::MaxPoolForwardNoMask_NCHW_VECT_C()(
        reinterpret_cast<const int32*>(input.flat<qint8>().data()),
        params.tensor_in_batch, params.tensor_in_rows, params.tensor_in_cols,
        params.depth, params.out_height, params.out_width, params.window_rows,
        params.window_cols, params.row_stride, params.col_stride,
        params.pad_rows, params.pad_cols,
        reinterpret_cast<int32*>(output->flat<qint8>().data()),
        context->eigen_gpu_device());
    if (!status) {
      context->SetStatus(errors::Internal(
          "Failed launching LaunchMaxPoolingNoMask_NCHW_VECT_C"));
    }
  }
};
#endif
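
// Note on the NCHW_VECT_C path above: that format stores the channel
// dimension in groups of 4 qint8 values (N x C/4 x H x W x 4), which is why
// the qint8 buffers are reinterpreted as int32 so that each load and store
// moves 4 channels at once.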

template <typename Device, typename T>
class MaxPoolingV2Op : public OpKernel {
 public:
  explicit MaxPoolingV2Op(OpKernelConstruction* context) : OpKernel(context) {
    string data_format;
    auto status = context->GetAttr("data_format", &data_format);
    if (status.ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
      OP_REQUIRES(
          context,
          data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW_VECT_C,
          errors::InvalidArgument(
              "MaxPoolingV2Op only supports NHWC or NCHW_VECT_C. Got: ",
              data_format));
    } else {
      data_format_ = FORMAT_NHWC;
    }
    if (context->num_inputs() == 1) {
      OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
      OP_REQUIRES(context, ksize_.size() == 4,
                  errors::InvalidArgument("Sliding window ksize field must "
                                          "specify 4 dimensions"));
      OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
      OP_REQUIRES(context, stride_.size() == 4,
                  errors::InvalidArgument("Sliding window stride field must "
                                          "specify 4 dimensions"));
      OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
                  errors::Unimplemented(
                      "Pooling is not yet supported on the batch dimension."));
    }
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor_in = context->input(0);

    std::vector<int32> ksize = ksize_;
    std::vector<int32> stride = stride_;

    if (context->num_inputs() != 1) {
      const Tensor& tensor_ksize = context->input(1);
      auto value_ksize = tensor_ksize.flat<int32>();
      ksize.resize(tensor_ksize.shape().num_elements());
      std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());

      const Tensor& tensor_stride = context->input(2);
      auto value_stride = tensor_stride.flat<int32>();
      stride.resize(tensor_stride.shape().num_elements());
      std::copy_n(&value_stride(0), stride.size(), stride.begin());
    }
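    // For example (illustrative values only): with NHWC input, a ksize input
    // of [1, 2, 2, 1] and a strides input of [1, 2, 2, 1] request 2x2 max
    // pooling with stride 2 over the spatial dimensions, the same thing the
    // "ksize"/"strides" attrs express in the attribute-only form.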

    OP_REQUIRES(context, ksize.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, stride.size() == 4,
                errors::InvalidArgument("Sliding window stride field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, ksize[0] == 1 && stride[0] == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));

    PoolParameters params{context,  ksize,        stride,
                          padding_, data_format_, tensor_in.shape()};
    if (!context->status().ok()) {
      return;
    }

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                0, params.forward_output_shape(), &output));

    if (params.depth_window > 1) {
      // Validate spec against the current implementation.  A
      // relaxation of these requirements would be ideal.
      OP_REQUIRES(context, params.depth % params.depth_window == 0,
                  errors::Unimplemented(
                      "Depthwise max pooling requires "
                      "the depth window to evenly divide the input depth."));
      OP_REQUIRES(
          context, params.depth_window == params.depth_stride,
          errors::Unimplemented("Depthwise max pooling requires "
                                "the depth window to equal the depth stride."));

      DepthwiseMaxPool(context, output, tensor_in, params);
    } else {
      SpatialMaxPool(context, output, tensor_in, params, padding_);
    }
  }

 private:
  // Single-threaded implementation of DepthwiseMaxPool, which does not
  // handle all of the same options as SpatialMaxPool (it makes strict
  // assumptions about padding and stride).
  //
  // TODO(vrv): implement a more general depthwise-max pool that works
  // on GPU as well.
  void DepthwiseMaxPool(OpKernelContext* context, Tensor* output,
                        const Tensor& tensor_in, const PoolParameters& params) {
    Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
        in_by_pool(tensor_in.flat<T>().data(), params.depth_window,
                   tensor_in.NumElements() / params.depth_window);
    Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> out_by_pool(
        output->flat<T>().data(), 1, output->NumElements());
    out_by_pool = in_by_pool.colwise().maxCoeff();
  }

  void SpatialMaxPool(OpKernelContext* context, Tensor* output,
                      const Tensor& tensor_in, const PoolParameters& params,
                      const Padding& padding) {
    // On GPU, use Eigen's Spatial Max Pooling.  On CPU, use an
    // EigenMatrix version that is currently faster than Eigen's
    // Spatial MaxPooling implementation.
    //
    // TODO(vrv): Remove this once we no longer need it.
#ifdef GOOGLE_CUDA
    if (std::is_same<Device, GPUDevice>::value) {
      Eigen::PaddingType pt = BrainPadding2EigenPadding(padding);
      if (std::is_same<T, qint8>::value) {
        LaunchMaxPoolingNoMask_NCHW_VECT_C<GPUDevice>::launch(
            context, params, tensor_in, output);
      } else {
        functor::SpatialMaxPooling<Device, T>()(
            context->eigen_device<Device>(), output->tensor<T, 4>(),
            tensor_in.tensor<T, 4>(), params.window_rows, params.window_cols,
            params.row_stride, params.col_stride, pt);
      }
    } else
#endif
    {
      typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
          ConstEigenMatrixMap;
      typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
          EigenMatrixMap;

      ConstEigenMatrixMap in_mat(tensor_in.flat<T>().data(), params.depth,
                                 params.tensor_in_cols * params.tensor_in_rows *
                                     params.tensor_in_batch);
      EigenMatrixMap out_mat(
          output->flat<T>().data(), params.depth,
          params.out_width * params.out_height * params.tensor_in_batch);

      const DeviceBase::CpuWorkerThreads& worker_threads =
          *(context->device()->tensorflow_cpu_worker_threads());

      // The following code does two things:
      // 1. Flattens the input and output tensors into two-dimensional arrays:
      //    tensor_in_as_matrix:
      //      depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch)
      //    output_as_matrix:
      //      depth by (out_width * out_height * tensor_in_batch)
      //
      // 2. Walks through the set of columns in the flattened
      //    tensor_in_as_matrix, and updates the corresponding column(s) in
      //    output_as_matrix with the max value.
      auto shard = [&params, &in_mat, &out_mat](int64 start, int64 limit) {
        const int32 in_rows = params.tensor_in_rows;
        const int32 in_cols = params.tensor_in_cols;
        const int32 pad_rows = params.pad_rows;
        const int32 pad_cols = params.pad_cols;
        const int32 window_rows = params.window_rows;
        const int32 window_cols = params.window_cols;
        const int32 row_stride = params.row_stride;
        const int32 col_stride = params.col_stride;
        const int32 out_height = params.out_height;
        const int32 out_width = params.out_width;

        {
          // Initializes the output tensor with MIN<T>.
          const int32 output_image_size = out_height * out_width * params.depth;
          EigenMatrixMap out_shard(out_mat.data() + start * output_image_size,
                                   1, (limit - start) * output_image_size);
          out_shard.setConstant(Eigen::NumTraits<T>::lowest());
        }

        for (int32 b = start; b < limit; ++b) {
          const int32 out_offset_batch = b * out_height;
          for (int32 h = 0; h < in_rows; ++h) {
            for (int32 w = 0; w < in_cols; ++w) {
              // (h_start, h_end) * (w_start, w_end) is the range that the input
              // vector projects to.
              const int32 hpad = h + pad_rows;
              const int32 wpad = w + pad_cols;
              const int32 h_start = (hpad < window_rows)
                                        ? 0
                                        : (hpad - window_rows) / row_stride + 1;
              const int32 h_end = std::min(hpad / row_stride + 1, out_height);
              const int32 w_start = (wpad < window_cols)
                                        ? 0
                                        : (wpad - window_cols) / col_stride + 1;
              const int32 w_end = std::min(wpad / col_stride + 1, out_width);
              // compute elementwise max
              const int32 in_offset = (b * in_rows + h) * in_cols + w;
              for (int32 ph = h_start; ph < h_end; ++ph) {
                const int32 out_offset_base =
                    (out_offset_batch + ph) * out_width;
                for (int32 pw = w_start; pw < w_end; ++pw) {
                  const int32 out_offset = out_offset_base + pw;
                  out_mat.col(out_offset) =
                      out_mat.col(out_offset).cwiseMax(in_mat.col(in_offset));
                }
              }
            }
          }
        }
      };

      // TODO(andydavis) Consider sharding across batch x rows x cols.
      // TODO(andydavis) Consider a higher resolution shard cost model.
      const int64 shard_cost =
          params.tensor_in_rows * params.tensor_in_cols * params.depth;
      Shard(worker_threads.num_threads, worker_threads.workers,
            params.tensor_in_batch, shard_cost, shard);
    }
  }

  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

template <typename Device, typename T>
void SpatialAvgPool(OpKernelContext* context, Tensor* output,
                    const Tensor& input, const PoolParameters& params,
                    const Padding& padding) {
  typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
      ConstEigenMatrixMap;
  typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
      EigenMatrixMap;

  auto in_flat = input.flat<T>();
  auto out_flat = output->flat<T>();

  auto shard = [&params, &in_flat, &out_flat](int64 start, int64 limit) {
    // Calculate indices for this shard's chunk of work.
    const int64 input_image_size =
        params.tensor_in_rows * params.tensor_in_cols * params.depth;
    const int64 output_image_size =
        params.out_width * params.out_height * params.depth;
    const int64 shard_batch_size = limit - start;

    ConstEigenMatrixMap in_mat(
        in_flat.data() + start * input_image_size, params.depth,
        params.tensor_in_cols * params.tensor_in_rows * shard_batch_size);
    EigenMatrixMap out_mat(
        out_flat.data() + start * output_image_size, params.depth,
        params.out_width * params.out_height * shard_batch_size);
    Eigen::Matrix<T, Eigen::Dynamic, 1> out_count(out_mat.cols());
    out_count.setZero();

    // Initializes output to zero.
    out_mat.setZero();

    // The following code does two things:
    // 1. Flattens the input and output tensors into two-dimensional arrays:
    //    tensor_in_as_matrix:
    //      depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch)
    //    output_as_matrix:
    //      depth by (out_width * out_height * tensor_in_batch)
    //
    // 2. Walks through the set of columns in the flattened
    //    tensor_in_as_matrix, and updates the corresponding column(s) in
    //    output_as_matrix with the average value.
    for (int b = 0; b < shard_batch_size; ++b) {
      for (int h = 0; h < params.tensor_in_rows; ++h) {
        for (int w = 0; w < params.tensor_in_cols; ++w) {
          // (h_start, h_end) * (w_start, w_end) is the range that the input
          // vector projects to.
          const int hpad = h + params.pad_rows;
          const int wpad = w + params.pad_cols;
          const int h_start =
              (hpad < params.window_rows)
                  ? 0
                  : (hpad - params.window_rows) / params.row_stride + 1;
          const int h_end =
              std::min<int>(hpad / params.row_stride + 1, params.out_height);
          const int w_start =
              (wpad < params.window_cols)
                  ? 0
                  : (wpad - params.window_cols) / params.col_stride + 1;
          const int w_end =
              std::min<int>(wpad / params.col_stride + 1, params.out_width);
          const int in_offset =
              (b * params.tensor_in_rows + h) * params.tensor_in_cols + w;
          Eigen::DSizes<Eigen::DenseIndex, 2> in_indices(0, in_offset);
          for (int ph = h_start; ph < h_end; ++ph) {
            for (int pw = w_start; pw < w_end; ++pw) {
              const int out_offset =
                  (b * params.out_height + ph) * params.out_width + pw;
              out_mat.col(out_offset) += in_mat.col(in_offset);
              out_count(out_offset) += T(1);
            }
          }
        }
      }
    }

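    // Each output column now holds the sum of its contributing input columns,
    // and out_count holds how many there were; dividing turns the sums into
    // averages. With SAME padding, border cells see fewer than
    // window_rows * window_cols inputs, so they are averaged over only the
    // in-bounds elements.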
    DCHECK_GT(out_count.minCoeff(), T(0));
    out_mat.array().rowwise() /= out_count.transpose().array();
  };

  const int64 work_unit_size =
      params.tensor_in_rows * params.tensor_in_cols * params.depth;
  // NOTE: Constants in the calculation below were estimated based on
  // benchmarking. Nanoseconds/work_unit for benchmarks ranged from 0.01 to
  // 0.001, so the factor 0.01 (i.e. 1/100), with a floor of 10000, was chosen
  // to keep the work unit cost in an operating range in which it empirically
  // performed best.
  const int64 work_unit_cost = std::max(10000LL, work_unit_size / 100LL);
  const DeviceBase::CpuWorkerThreads& worker_threads =
      *(context->device()->tensorflow_cpu_worker_threads());
  Shard(worker_threads.num_threads, worker_threads.workers,
        params.tensor_in_batch, work_unit_cost, shard);
}
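
// A minimal usage sketch for SpatialAvgPool (assumptions: an NHWC float CPU
// kernel along the lines of the AvgPool kernels in avgpooling_op.cc; names
// here are illustrative):
//
//   PoolParameters params{context, ksize, stride, padding,
//                         FORMAT_NHWC, tensor_in.shape()};
//   if (!context->status().ok()) return;
//   Tensor* output = nullptr;
//   OP_REQUIRES_OK(context, context->allocate_output(
//                               0, params.forward_output_shape(), &output));
//   SpatialAvgPool<Eigen::ThreadPoolDevice, float>(context, output, tensor_in,
//                                                  params, padding);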

}  // namespace tensorflow

#endif  // TENSORFLOW_KERNELS_POOLING_OPS_COMMON_H_