/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/pooling_ops_common.h"

#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/pooling_ops_common_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

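// PoolParameters validates the pooling window, strides, padding, and data
// format against the input shape, and computes the output height/width/depth
// and padding amounts used by both the CPU and GPU pooling kernels.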
PoolParameters::PoolParameters(OpKernelContext* context,
                               const std::vector<int32>& ksize,
                               const std::vector<int32>& stride,
                               Padding padding, TensorFormat data_format,
                               const TensorShape& tensor_in_shape) {
  // For maxpooling, tensor_in should have 2 spatial dimensions.
  // Note: the total number of dimensions could be 4 for NHWC, NCHW,
  // or 5 for NCHW_VECT_C.
  OP_REQUIRES(context,
              GetTensorSpatialDims(tensor_in_shape.dims(), data_format) == 2,
              errors::InvalidArgument(
                  "tensor_in_shape must have 2 spatial dimensions. ",
                  tensor_in_shape.dims(), " ", data_format));

  this->data_format = data_format;
  depth = GetTensorDim(tensor_in_shape, data_format, 'C') *
          (data_format == FORMAT_NCHW_VECT_C ? 4 : 1);
  tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, 'W');
  tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, 'H');
  tensor_in_batch = GetTensorDim(tensor_in_shape, data_format, 'N');
  window_rows = GetTensorDim(ksize, data_format, 'H');
  window_cols = GetTensorDim(ksize, data_format, 'W');
  depth_window = GetTensorDim(ksize, data_format, 'C');
  row_stride = GetTensorDim(stride, data_format, 'H');
  col_stride = GetTensorDim(stride, data_format, 'W');
  depth_stride = GetTensorDim(stride, data_format, 'C');

  // We only support 2D pooling across width/height or depthwise pooling,
  // not a combination of the two.
  OP_REQUIRES(context,
              (depth_window == 1 || (window_rows == 1 && window_cols == 1)),
              errors::Unimplemented(
                  "MaxPooling supports exactly one of pooling across depth "
                  "or pooling across width/height."));

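  // For spatial pooling, GetWindowedOutputSize computes the output extent and
  // the padding from the input extent, window, stride, and padding scheme:
  // roughly ceil((in - window + 1) / stride) for VALID and ceil(in / stride)
  // for SAME.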
  if (depth_window == 1) {
    OP_REQUIRES_OK(
        context, GetWindowedOutputSize(tensor_in_rows, window_rows, row_stride,
                                       padding, &out_height, &pad_rows));
    OP_REQUIRES_OK(
        context, GetWindowedOutputSize(tensor_in_cols, window_cols, col_stride,
                                       padding, &out_width, &pad_cols));
    pad_depth = 0;
    out_depth = depth;
  } else {
    // Our current version of depthwise max pooling does not support
    // any padding, and expects the depth_window to equal the
    // depth_stride (no overlapping).
    OP_REQUIRES(
        context, depth % depth_window == 0,
        errors::Unimplemented("Depthwise max pooling requires the depth "
                              "window to evenly divide the input depth"));
    OP_REQUIRES(
        context, depth_stride == depth_window,
        errors::Unimplemented("Depthwise max pooling requires the depth "
                              "window to equal the depth stride"));

    // The current version of depthwise max is only implemented on CPU.
    OP_REQUIRES(context,
                (DeviceType(static_cast<Device*>(context->device())
                                ->attributes()
                                .device_type()) == DeviceType(DEVICE_CPU)),
                errors::Unimplemented("Depthwise max pooling is currently "
                                      "only implemented for CPU devices."));

    pad_depth = 0;
    out_depth = depth / depth_window;
  }
}

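// Returns the shape of the pooled output: for spatial pooling the depth is
// preserved and the spatial dimensions shrink; for depthwise pooling the
// spatial dimensions are preserved and the depth shrinks (the shape is built
// directly in NHWC order).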
TensorShape PoolParameters::forward_output_shape() {
  if (depth_window == 1) {
    // Spatial pooling
    return ShapeFromFormat(data_format, tensor_in_batch, out_height, out_width,
                           depth);
  } else {
    // Depthwise pooling
    return TensorShape(
        {tensor_in_batch, tensor_in_rows, tensor_in_cols, out_depth});
  }
}

#ifdef GOOGLE_CUDA

namespace {
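// Wraps a raw device pointer from a TensorFlow GPU allocation into a typed
// StreamExecutor DeviceMemory<T> handle, which is what the dnn calls below
// expect. `size` is the number of elements, not bytes.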
template <typename T>
perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
                                                    uint64 size) {
  perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
                                                size * sizeof(T));
  perftools::gputools::DeviceMemory<T> typed(wrapped);
  return typed;
}
}  // namespace

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                         \
  template <>                                                       \
  void TransformDepth<GPUDevice, T, Eigen::DenseIndex>::operator()( \
      const GPUDevice& d, typename TTypes<T, 4>::ConstTensor in,    \
      const Eigen::DSizes<Eigen::DenseIndex, 4>& shuffle,           \
      typename TTypes<T, 4>::Tensor out);                           \
  extern template struct TransformDepth<GPUDevice, T, Eigen::DenseIndex>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC)
#undef DECLARE_GPU_SPEC
}  // namespace functor

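// Runs a forward pooling pass through StreamExecutor/cuDNN: allocates the
// output, transposes NHWC input to NCHW if needed, builds the pooling and
// tensor descriptors, launches ThenPoolForward, and transposes the result
// back. A pooling kernel would typically invoke it along these lines (a
// sketch; the member names are illustrative, not from this file):
//
//   DnnPoolingOp<T>::Compute(context,
//                            perftools::gputools::dnn::PoolingMode::kMaximum,
//                            ksize_, stride_, padding_, data_format_,
//                            tensor_in, out_shape, propagate_nans_);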
template <typename T>
void DnnPoolingOp<T>::Compute(
    OpKernelContext* context,
    perftools::gputools::dnn::PoolingMode pooling_mode,
    const std::vector<int32>& size, const std::vector<int32>& stride,
    Padding padding, TensorFormat data_format, const Tensor& tensor_in,
    const TensorShape& tensor_out_shape, bool propagate_nans) {
  Tensor* tensor_out = nullptr;
  OP_REQUIRES_OK(context,
                 context->allocate_output(0, tensor_out_shape, &tensor_out));
  if (tensor_in.shape().num_elements() == 0) {
    return;
  }

  PoolParameters params{context, size,        stride,
                        padding, data_format, tensor_in.shape()};
  if (!context->status().ok()) {
    return;
  }

  /// cuDNN does not yet support the NHWC layout, so NHWC input is transposed
  /// to NCHW before calling cuDNN and the result is transposed back
  /// afterwards. This extra copy can be removed once cuDNN supports NHWC.
  Tensor transformed_input;
  if (data_format == FORMAT_NHWC) {
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<T>::value,
                                ShapeFromFormat(FORMAT_NCHW, tensor_in.shape(),
                                                data_format),
                                &transformed_input));
    functor::NHWCToNCHW<GPUDevice, T, 4>()(context->eigen_device<Device>(),
                                           tensor_in.tensor<T, 4>(),
                                           transformed_input.tensor<T, 4>());
  } else {
    transformed_input = tensor_in;
  }
  Tensor transformed_output;
  if (data_format == FORMAT_NHWC) {
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<T>::value,
                                ShapeFromFormat(FORMAT_NCHW, tensor_out_shape,
                                                data_format),
                                &transformed_output));
  } else {
    transformed_output = *tensor_out;
  }

  /// Set up the cuDNN pooling descriptor.
  perftools::gputools::dnn::PoolingDescriptor pooling_desc;
  pooling_desc.set_pooling_mode(pooling_mode)
      .set_window_height(params.window_rows)
      .set_window_width(params.window_cols)
      .set_vertical_stride(params.row_stride)
      .set_horizontal_stride(params.col_stride)
      .set_vertical_padding(params.pad_rows)
      .set_horizontal_padding(params.pad_cols)
      .set_propagate_nans(propagate_nans);

  perftools::gputools::dnn::BatchDescriptor input_desc;
  input_desc.set_count(params.tensor_in_batch)
      .set_height(params.tensor_in_rows)
      .set_width(params.tensor_in_cols)
      .set_feature_map_count(params.depth)
      .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);

  perftools::gputools::dnn::BatchDescriptor output_desc;
  output_desc.set_count(params.tensor_in_batch)
      .set_height(params.out_height)
      .set_width(params.out_width)
      .set_feature_map_count(params.depth)
      .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);

  auto input_data = AsDeviceMemory(transformed_input.template flat<T>().data(),
                                   transformed_input.template flat<T>().size());
  auto output_data =
      AsDeviceMemory(transformed_output.template flat<T>().data(),
                     transformed_output.template flat<T>().size());

  auto* stream = context->op_device_context()->stream();
  OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

  bool status = stream
                    ->ThenPoolForward(pooling_desc, input_desc, input_data,
                                      output_desc, &output_data)
                    .ok();
  OP_REQUIRES(context, status,
              errors::Internal("cudnn PoolForward launch failed"));

  if (data_format == FORMAT_NHWC) {
    /// Transform the output data from NCHW back to NHWC.
    auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
    functor::NCHWToNHWC<GPUDevice, T, 4>()(
        context->eigen_device<Device>(),
        toConstTensor(transformed_output).template tensor<T, 4>(),
        tensor_out->tensor<T, 4>());
  }
}

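// Runs the backward pooling pass through StreamExecutor/cuDNN. For max
// pooling, cuDNN needs the original input and output tensors in addition to
// the output gradient; for average pooling those arguments may be null, and
// placeholder NCHW buffers are allocated instead. As with the forward pass,
// NHWC data is transposed to NCHW around the cuDNN call. A sketch of a
// typical call site (member names are illustrative, not from this file):
//
//   DnnPoolingGradOp<T>::Compute(
//       context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize_,
//       stride_, padding_, data_format_, &tensor_in, &tensor_out,
//       out_backprop, input_shape, propagate_nans_);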
template <typename T>
void DnnPoolingGradOp<T>::Compute(
    OpKernelContext* context,
    perftools::gputools::dnn::PoolingMode pooling_mode,
    const std::vector<int32>& size, const std::vector<int32>& stride,
    Padding padding, TensorFormat data_format, const Tensor* tensor_in,
    const Tensor* tensor_out, const Tensor& out_backprop,
    const TensorShape& tensor_in_shape, bool propagate_nans) {
  CHECK((pooling_mode != perftools::gputools::dnn::PoolingMode::kMaximum) ||
        (tensor_in && tensor_out))
      << "For MaxPoolGrad, both tensor_in and tensor_out need to be "
         "specified";

  Tensor* input_backprop = nullptr;
  OP_REQUIRES_OK(context,
                 context->allocate_output(0, tensor_in_shape, &input_backprop));
  if (tensor_in_shape.num_elements() == 0) {
    return;
  }

  PoolParameters params{context, size,        stride,
                        padding, data_format, tensor_in_shape};
  if (!context->status().ok()) {
    return;
  }

  /// cuDNN does not yet support the NHWC layout, so NHWC data is transposed
  /// to NCHW around the cuDNN call. This extra copy can be removed once cuDNN
  /// supports NHWC directly.
  Tensor transformed_input;
  TensorShape transformed_input_shape;
  if (data_format == FORMAT_NHWC || !tensor_in) {
    transformed_input_shape =
        ShapeFromFormat(FORMAT_NCHW, tensor_in_shape, data_format);
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
                                                   transformed_input_shape,
                                                   &transformed_input));
  } else {
    transformed_input = *tensor_in;
  }
  Tensor transformed_output;
  TensorShape transformed_output_shape;
  if (data_format == FORMAT_NHWC || !tensor_out) {
    transformed_output_shape =
        ShapeFromFormat(FORMAT_NCHW, out_backprop.shape(), data_format);
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
                                                   transformed_output_shape,
                                                   &transformed_output));
  } else {
    transformed_output = *tensor_out;
  }
  Tensor transformed_input_backprop;
  if (data_format == FORMAT_NHWC) {
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<T>::value,
                                          transformed_input_shape,
                                          &transformed_input_backprop));
  } else {
    transformed_input_backprop = *input_backprop;
  }
  Tensor transformed_output_backprop;
  if (data_format == FORMAT_NHWC) {
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<T>::value,
                                          transformed_output_shape,
                                          &transformed_output_backprop));
  } else {
    transformed_output_backprop = out_backprop;
  }

  if (data_format == FORMAT_NHWC) {
    /// Convert the tensors from NHWC to NCHW.
    if (tensor_in) {
      // For AvgPoolGrad, the original input tensor is not needed. However,
      // cudnn still requires it to be passed, even though it does not affect
      // the results.
      functor::NHWCToNCHW<GPUDevice, T, 4>()(context->eigen_device<Device>(),
                                             tensor_in->tensor<T, 4>(),
                                             transformed_input.tensor<T, 4>());
    }
    if (tensor_out) {
      // For AvgPoolGrad, the original output tensor is not needed. However,
      // cudnn still requires it to be passed, even though it does not affect
      // the results.
      functor::NHWCToNCHW<GPUDevice, T, 4>()(context->eigen_device<Device>(),
                                             tensor_out->tensor<T, 4>(),
                                             transformed_output.tensor<T, 4>());
    }
    functor::NHWCToNCHW<GPUDevice, T, 4>()(
        context->eigen_device<Device>(), out_backprop.tensor<T, 4>(),
        transformed_output_backprop.tensor<T, 4>());
  }

  /// Set up the cuDNN pooling descriptor.
  perftools::gputools::dnn::PoolingDescriptor pooling_desc;
  pooling_desc.set_pooling_mode(pooling_mode)
      .set_window_height(params.window_rows)
      .set_window_width(params.window_cols)
      .set_vertical_stride(params.row_stride)
      .set_horizontal_stride(params.col_stride)
      .set_vertical_padding(params.pad_rows)
      .set_horizontal_padding(params.pad_cols)
      .set_propagate_nans(propagate_nans);

  perftools::gputools::dnn::BatchDescriptor orig_output_desc;
  orig_output_desc.set_count(params.tensor_in_batch)
      .set_height(params.out_height)
      .set_width(params.out_width)
      .set_feature_map_count(params.depth)
      .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);

  perftools::gputools::dnn::BatchDescriptor orig_input_desc;
  orig_input_desc.set_count(params.tensor_in_batch)
      .set_height(params.tensor_in_rows)
      .set_width(params.tensor_in_cols)
      .set_feature_map_count(params.depth)
      .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);

  auto orig_output_data =
      AsDeviceMemory(transformed_output.template flat<T>().data(),
                     transformed_output.template flat<T>().size());
  auto orig_input_data =
      AsDeviceMemory(transformed_input.template flat<T>().data(),
                     transformed_input.template flat<T>().size());
  auto output_backprop_data =
      AsDeviceMemory(transformed_output_backprop.template flat<T>().data(),
                     transformed_output_backprop.template flat<T>().size());
  auto input_backprop_data =
      AsDeviceMemory(transformed_input_backprop.template flat<T>().data(),
                     transformed_input_backprop.template flat<T>().size());

  auto* stream = context->op_device_context()->stream();
  OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

  bool status =
      stream
          ->ThenPoolBackward(pooling_desc, orig_input_desc, orig_input_data,
                             orig_output_desc, orig_output_data,
                             output_backprop_data, &input_backprop_data)
          .ok();
  OP_REQUIRES(context, status,
              errors::Internal("cudnn PoolBackward launch failed"));

  if (data_format == FORMAT_NHWC) {
    /// Transform the output data from NCHW back to NHWC.
    auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
    functor::NCHWToNHWC<GPUDevice, T, 4>()(
        context->eigen_device<Device>(),
        toConstTensor(transformed_input_backprop).template tensor<T, 4>(),
        input_backprop->tensor<T, 4>());
  }
}

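// Explicitly instantiate DnnPoolingOp and DnnPoolingGradOp for the GPU
// floating-point types (e.g. Eigen::half and float) covered by
// TF_CALL_GPU_NUMBER_TYPES.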
#define DEFINE_DNN_OPS(T)         \
  template class DnnPoolingOp<T>; \
  template class DnnPoolingGradOp<T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_DNN_OPS)
#undef DEFINE_DNN_OPS

#endif  // GOOGLE_CUDA

}  // namespace tensorflow