// tensorflow/core/kernels/mkl_pooling_ops_common.cc (code-browser header removed)
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifdef INTEL_MKL
     17 
     18 #include "tensorflow/core/kernels/mkl_pooling_ops_common.h"
     19 #include <limits>
     20 #include <vector>
     21 #include "tensorflow/core/common_runtime/device.h"
     22 #include "tensorflow/core/framework/bounds_check.h"
     23 #include "tensorflow/core/framework/common_shape_fns.h"
     24 
     25 namespace tensorflow {
     26 
     27 #ifndef INTEL_MKL_ML_ONLY
     28 
     29 using mkldnn::pooling_avg;
     30 using mkldnn::pooling_avg_exclude_padding;
     31 using mkldnn::pooling_avg_include_padding;
     32 using mkldnn::pooling_max;
     33 using mkldnn::prop_kind;
     34 
template <typename T>
void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
  // Only max pooling and the three average-pooling variants are supported.
  DCHECK(fwdParams.alg_kind == pooling_max ||
         fwdParams.alg_kind == pooling_avg ||
         fwdParams.alg_kind == pooling_avg_include_padding ||
         fwdParams.alg_kind == pooling_avg_exclude_padding)
      << "Pooling algorithm kind is not supported";

  // Cache algorithm and propagation kind; Execute() consults them to decide
  // whether a workspace buffer must be attached.
  context_.alg_kind = fwdParams.alg_kind;
  context_.prop_kind = fwdParams.prop_kind;

  // create memory desc
  // FIXME: Pooling doesn't expose to get the src_primitive_desc,
  //        so src format is currently hard-coded.
  //        A utility function is used to do this,
  //        which may be broken with future CPU architectures
  bool is_2d = (fwdParams.src_dims.size() == 4);
  // Quantized types always use channels-last (NHWC/NDHWC); other types use
  // whatever blocked format get_desired_format() picks for the channel count.
  if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value)
    context_.src_fmt = is_2d ? memory::format::nhwc : memory::format::ndhwc;
  else
    context_.src_fmt = get_desired_format(fwdParams.src_dims[1], is_2d);

  context_.src_md.reset(new memory::desc({fwdParams.src_dims}, MklDnnType<T>(),
                                         context_.src_fmt));
  // Destination format is 'any': MKL-DNN chooses the best layout and we read
  // the chosen format back from the primitive descriptor below.
  context_.dst_md.reset(new memory::desc({fwdParams.dst_dims}, MklDnnType<T>(),
                                         memory::format::any));

  // create a pooling descriptor
  context_.fwd_desc.reset(new pooling_forward::desc(
      fwdParams.prop_kind, fwdParams.alg_kind, *context_.src_md,
      *context_.dst_md, fwdParams.strides, fwdParams.filter_dims,
      fwdParams.padding_left, fwdParams.padding_right, padding_kind::zero));
  context_.fwd_pd.reset(
      new pooling_forward::primitive_desc(*context_.fwd_desc, cpu_engine_));

  // store expected primitive format
  context_.dst_fmt = static_cast<mkldnn::memory::format>(
      context_.fwd_pd.get()->dst_primitive_desc().desc().data.format);

  // create MKL-DNN internal memory object with dummy data
  // (real data handles are swapped in by Execute()).
  context_.src_mem.reset(new memory(
      {{{fwdParams.src_dims}, MklDnnType<T>(), context_.src_fmt}, cpu_engine_},
      DummyData));
  context_.dst_mem.reset(
      new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));

  // for max pooling, need to return workspace(ws) for backward computing
  if (fwdParams.alg_kind == pooling_max &&
      fwdParams.prop_kind == prop_kind::forward_training) {
    auto ws_pd = context_.fwd_pd.get()->workspace_primitive_desc().desc().data;
    // store workspace's dims and format to create workspace tensor
    context_.ws_fmt = static_cast<mkldnn::memory::format>(ws_pd.format);
    context_.ws_dims.assign(ws_pd.dims, ws_pd.dims + ws_pd.ndims);
    context_.ws_dt = static_cast<mkldnn::memory::data_type>(ws_pd.data_type);
    context_.ws_size =
        context_.fwd_pd.get()->workspace_primitive_desc().get_size();
    context_.ws_mem.reset(new memory(
        context_.fwd_pd.get()->workspace_primitive_desc(), DummyData));
    // Forward primitive with workspace output (max pooling in training).
    context_.fwd.reset(new pooling_forward(*context_.fwd_pd, *context_.src_mem,
                                           *context_.dst_mem,
                                           *context_.ws_mem));
  } else {
    // Forward primitive without workspace (avg pooling, or inference).
    context_.fwd.reset(new pooling_forward(*context_.fwd_pd, *context_.src_mem,
                                           *context_.dst_mem));
  }

  context_.fwd_primitives.push_back(*context_.fwd);
}
    103 
    104 template <typename T>
    105 void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
    106                                         void* ws_data) {
    107   context_.src_mem->set_data_handle(
    108       static_cast<void*>(const_cast<T*>(src_data)));
    109   context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
    110   if (context_.alg_kind == pooling_max &&
    111       context_.prop_kind ==
    112           prop_kind::forward_training) {  // max pooling must have ws
    113     DCHECK(ws_data != nullptr);
    114     context_.ws_mem->set_data_handle(ws_data);
    115   }
    116   context_.fwd_stream->submit(context_.fwd_primitives);
    117 
    118   // set back data handle
    119   context_.src_mem->set_data_handle(DummyData);
    120   context_.dst_mem->set_data_handle(DummyData);
    121   if (context_.alg_kind == pooling_max &&
    122       context_.prop_kind ==
    123           prop_kind::forward_training) {  // max pooling must have ws
    124     DCHECK(ws_data != nullptr);
    125     context_.ws_mem->set_data_handle(DummyData);
    126   }
    127 }
    128 
    129 template class MklPoolingFwdPrimitive<float>;
    130 template class MklPoolingFwdPrimitive<quint8>;
    131 template class MklPoolingFwdPrimitive<qint8>;
    132 
template <typename T>
void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
  // Only max pooling and the three average-pooling variants are supported.
  DCHECK(bwdParams.alg_kind == pooling_max ||
         bwdParams.alg_kind == pooling_avg ||
         bwdParams.alg_kind == pooling_avg_include_padding ||
         bwdParams.alg_kind == pooling_avg_exclude_padding)
      << "Pooling algorithm kind is not supported";
  // Cached so Execute() knows whether a workspace handle must be attached.
  context_.alg_kind = bwdParams.alg_kind;

  // check whether it is 2d or 3d
  bool is_2d = (bwdParams.dst_dims.size() == 4);
  // Create memory desc
  // diff_src uses 'any' so MKL-DNN selects its preferred layout; the chosen
  // format is read back from the primitive descriptor below.
  context_.diff_src_md.reset(new memory::desc(
      {bwdParams.src_dims}, MklDnnType<T>(), memory::format::any));
  context_.diff_dst_md.reset(
      new memory::desc({bwdParams.dst_dims}, MklDnnType<T>(),
                       get_desired_format(bwdParams.dst_dims[1], is_2d)));
  context_.bwd_desc.reset(new pooling_backward::desc(
      bwdParams.alg_kind, *context_.diff_src_md, *context_.diff_dst_md,
      bwdParams.strides, bwdParams.filter_dims, bwdParams.padding_left,
      bwdParams.padding_right, padding_kind::zero));

  // create a forward primitive,
  // which will be used as a hint for creating backward primitive
  context_.fwd_desc.reset(new pooling_forward::desc(
      bwdParams.prop_kind, bwdParams.alg_kind, *context_.diff_src_md,
      *context_.diff_dst_md, bwdParams.strides, bwdParams.filter_dims,
      bwdParams.padding_left, bwdParams.padding_right, padding_kind::zero));
  context_.fwd_pd.reset(
      new pooling_forward::primitive_desc(*context_.fwd_desc, cpu_engine));
  context_.bwd_pd.reset(new pooling_backward::primitive_desc(
      *context_.bwd_desc, cpu_engine, *context_.fwd_pd));

  // store expected primitive format
  context_.diff_src_fmt = static_cast<mkldnn::memory::format>(
      context_.bwd_pd.get()->diff_src_primitive_desc().desc().data.format);
  context_.diff_dst_fmt = get_desired_format(bwdParams.dst_dims[1], is_2d);

  // create MKL-DNN internal memory object with dummy data
  // (real handles are attached per call in Execute()).
  context_.diff_src_mem.reset(
      new memory(context_.bwd_pd.get()->diff_src_primitive_desc(), DummyData));
  context_.diff_dst_mem.reset(new memory(
      {{{bwdParams.dst_dims}, MklDnnType<T>(), context_.diff_dst_fmt},
       cpu_engine},
      DummyData));

  // for max pooling, need to return workspace for backward
  if (bwdParams.alg_kind == pooling_max) {
    // Workspace dims/dtype come from the forward hint's workspace descriptor.
    // NOTE(review): ws_fmt is derived from get_desired_format() rather than
    // from ws_pd.format as the forward Setup does — presumably equivalent
    // here; confirm against the forward primitive's workspace layout.
    auto ws_pd = context_.fwd_pd.get()->workspace_primitive_desc().desc().data;
    context_.ws_dims.assign(ws_pd.dims, ws_pd.dims + ws_pd.ndims);
    context_.ws_fmt = get_desired_format(context_.ws_dims[1], is_2d);
    context_.ws_dt = static_cast<mkldnn::memory::data_type>(ws_pd.data_type);
    context_.ws_mem.reset(new memory(
        {{{context_.ws_dims}, context_.ws_dt, context_.ws_fmt}, cpu_engine},
        DummyData));
    // Backward max pooling consumes the workspace produced by forward.
    context_.bwd.reset(
        new pooling_backward(*context_.bwd_pd, *context_.diff_dst_mem,
                             *context_.ws_mem, *context_.diff_src_mem));
  } else {
    // Average pooling needs no workspace.
    context_.bwd.reset(new pooling_backward(
        *context_.bwd_pd, *context_.diff_dst_mem, *context_.diff_src_mem));
  }
  context_.bwd_primitives.push_back(*context_.bwd);
}
    197 
    198 template <typename T>
    199 void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
    200                                         T* diff_src_data, const void* ws_data) {
    201   context_.diff_dst_mem->set_data_handle(
    202       static_cast<void*>(const_cast<T*>(diff_dst_data)));
    203   context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
    204   if (context_.alg_kind == pooling_max) {
    205     DCHECK(ws_data != nullptr);
    206     context_.ws_mem->set_data_handle(const_cast<void*>(ws_data));
    207   }
    208 
    209   context_.bwd_stream->submit(context_.bwd_primitives);
    210   //  set back data handle
    211   context_.diff_dst_mem->set_data_handle(DummyData);
    212   context_.diff_src_mem->set_data_handle(DummyData);
    213   if (context_.alg_kind == pooling_max) {
    214     DCHECK(ws_data != nullptr);
    215     context_.ws_mem->set_data_handle(DummyData);
    216   }
    217 }
    218 
    219 template class MklPoolingBwdPrimitive<float>;
    220 
    221 #endif
    222 
    223 // Initialization for TensorFlow format
    224 void MklPoolParameters::Init(OpKernelContext* context,
    225                              const std::vector<int32>& ksize,
    226                              const std::vector<int32>& stride, Padding padding,
    227                              TensorFormat data_format,
    228                              const TensorShape& tensor_in_shape) {
    229   // For maxpooling, tensor_in should have 4 or 5 dimensions.
    230   OP_REQUIRES(context,
    231               tensor_in_shape.dims() == 4 || tensor_in_shape.dims() == 5,
    232               errors::InvalidArgument("tensor_in must be 4 or 5-dimensional"));
    233 
    234   depth = GetTensorDim(tensor_in_shape, data_format, 'C');
    235   if (tensor_in_shape.dims() == 4) {
    236     // Pool2D
    237     tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, 'W');
    238     tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, 'H');
    239   } else {
    240     // Pool3D
    241     tensor_in_planes = GetTensorDim(tensor_in_shape, data_format, '0');
    242     tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, '1');
    243     tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, '2');
    244   }
    245   tensor_in_batch = GetTensorDim(tensor_in_shape, data_format, 'N');
    246 
    247   Init(context, ksize, stride, padding, data_format);
    248 }
    249 
    250 #ifdef INTEL_MKL_ML_ONLY
    251 // Initialization for MKL format
    252 void MklPoolParameters::Init(OpKernelContext* context,
    253                              const std::vector<int32>& ksize,
    254                              const std::vector<int32>& stride, Padding padding,
    255                              TensorFormat data_format,
    256                              const MklShape* mklInputShape) {
    257   // Get the input sizes
    258   depth = mklInputShape->GetSizes()[2];
    259   tensor_in_cols = mklInputShape->GetSizes()[0];
    260   tensor_in_rows = mklInputShape->GetSizes()[1];
    261   tensor_in_batch = mklInputShape->GetSizes()[3];
    262 
    263   Init(context, ksize, stride, padding, data_format);
    264 }
    265 #else
    266 // Initialization for MKL format
    267 void MklPoolParameters::Init(OpKernelContext* context,
    268                              const std::vector<int32>& ksize,
    269                              const std::vector<int32>& stride, Padding padding,
    270                              TensorFormat data_format,
    271                              const MklDnnShape* mklInputShape) {
    272   // Get the input sizes
    273   if (ksize.size() == 4) {
    274     // Pool2D
    275     depth = mklInputShape->GetDimension('C');
    276     tensor_in_cols = mklInputShape->GetDimension('W');
    277     tensor_in_rows = mklInputShape->GetDimension('H');
    278     tensor_in_batch = mklInputShape->GetDimension('N');
    279   } else {
    280     // Pool3D
    281     depth = mklInputShape->GetDimension3D('C');
    282     tensor_in_cols = mklInputShape->GetDimension3D('W');
    283     tensor_in_rows = mklInputShape->GetDimension3D('H');
    284     tensor_in_planes = mklInputShape->GetDimension3D('D');
    285     tensor_in_batch = mklInputShape->GetDimension3D('N');
    286   }
    287 
    288   Init(context, ksize, stride, padding, data_format);
    289 }
    290 #endif  // INTEL_MKL_ML_ONLY
// Common Initialization for TensorFlow and MKL formats.
// Extracts window sizes and strides from ksize/stride, validates the
// configuration, and computes output sizes and padding amounts.
// On any invalid input, OP_REQUIRES sets context->status and returns early,
// leaving the remaining fields unset.
void MklPoolParameters::Init(OpKernelContext* context,
                             const std::vector<int32>& ksize,
                             const std::vector<int32>& stride, Padding padding,
                             TensorFormat data_format) {
  // Get the data format
  this->data_format = data_format;

  // 4-element ksize => 2D pooling; otherwise 3D pooling.
  bool is_pool2d = (ksize.size() == 4);
  if (is_pool2d) {
    // Pool2D
    // Get the output sizes
    window_rows = GetTensorDim(ksize, data_format, 'H');
    window_cols = GetTensorDim(ksize, data_format, 'W');
    depth_window = GetTensorDim(ksize, data_format, 'C');

    // Get the strides
    row_stride = GetTensorDim(stride, data_format, 'H');
    col_stride = GetTensorDim(stride, data_format, 'W');
    depth_stride = GetTensorDim(stride, data_format, 'C');

    // We only support 2D pooling across width/height and depthwise
    // pooling, not a combination.
    OP_REQUIRES(context,
                (depth_window == 1 || (window_rows == 1 && window_cols == 1)),
                errors::Unimplemented(
                    "MaxPooling supports exactly one of pooling across depth "
                    "or pooling across width/height."));
  } else {
    // Pool3D
    // Get the output sizes ('0','1','2' address planes, rows, cols).
    window_planes = GetTensorDim(ksize, data_format, '0');
    window_rows = GetTensorDim(ksize, data_format, '1');
    window_cols = GetTensorDim(ksize, data_format, '2');
    depth_window = GetTensorDim(ksize, data_format, 'C');

    // Get the strides
    planes_stride = GetTensorDim(stride, data_format, '0');
    row_stride = GetTensorDim(stride, data_format, '1');
    col_stride = GetTensorDim(stride, data_format, '2');
    depth_stride = GetTensorDim(stride, data_format, 'C');

    // We only support 3D pooling across depth/width/height and depthwise
    // pooling, not a combination.
    OP_REQUIRES(context,
                (depth_window == 1 ||
                 (window_rows == 1 && window_cols == 1 && window_planes == 1)),
                errors::Unimplemented(
                    "AvgPooling3D supports exactly one of pooling across depth "
                    "or pooling across depth/width/height."));
  }

  if (depth_window == 1) {  // we are pooling in the D (Pool3D only), H and W
    // Compute output extents and the explicit padding for each spatial dim.
    if (!is_pool2d) {
      OP_REQUIRES_OK(
          context, GetWindowedOutputSizeVerbose(tensor_in_planes, window_planes,
                                                planes_stride, padding,
                                                &out_planes, &pad_P1, &pad_P2));
    }

    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                tensor_in_rows, window_rows, row_stride,
                                padding, &out_height, &pad_top, &pad_bottom));

    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                tensor_in_cols, window_cols, col_stride,
                                padding, &out_width, &pad_left, &pad_right));
#ifndef INTEL_MKL_ML_ONLY
    // TF can work with int64, but mkldnn only supports int32
    // Fail if the depth, height or width are greater than MAX_INT
    // We check depth only for 3D pooling case

    if (!is_pool2d) {
      OP_REQUIRES(context,
                  FastBoundsCheck(out_planes, std::numeric_limits<int>::max()),
                  errors::InvalidArgument("output depth/planes is too large"));
    }

    OP_REQUIRES(context,
                FastBoundsCheck(out_height, std::numeric_limits<int>::max()),
                errors::InvalidArgument("output height is too large"));

    OP_REQUIRES(context,
                FastBoundsCheck(out_width, std::numeric_limits<int>::max()),
                errors::InvalidArgument("output width is too large"));
#endif
    out_depth = depth;  // output will have the same depth as the input
  } else {              // we are pooling in the depth dimension
    // Our current version of depthwise max pooling does not support
    // any padding, and expects the depth_window to equal the depth
    // stride (no overlapping).
    OP_REQUIRES(context, depth % depth_window == 0,
                errors::Unimplemented("Depthwise max pooling requires the"
                                      " depth window to evenly divide the"
                                      " input depth"));
    OP_REQUIRES(context, depth_stride == depth_window,
                errors::Unimplemented("Depthwise max pooling requires the"
                                      " depth window to equal the depth"
                                      " stride"));

    // The current version of depthwise max is only implemented on CPU.
    OP_REQUIRES(context,
                (DeviceType(static_cast<Device*>(context->device())
                                ->attributes()
                                .device_type()) == DeviceType(DEVICE_CPU)),
                errors::Unimplemented("Depthwise max pooling is currently "
                                      "only implemented for CPU devices."));

    out_depth = depth / depth_window;
  }
}
    402 
    403 // Transfers the right parameters for pooling to the op parameters
    404 // Updates context->status if there is an invalid input.
    405 void ExtractMklOpParams(OpKernelContext* context, TensorFormat data_format,
    406                         const MklPoolParameters& params,
    407                         MklPoolingOpParams* mkl_params) {
    408   mkl_params->in_sizes[0] = params.tensor_in_cols;
    409   mkl_params->in_sizes[1] = params.tensor_in_rows;
    410   mkl_params->in_sizes[2] = params.depth;
    411   mkl_params->in_sizes[3] = params.tensor_in_batch;
    412 
    413   GetStridesFromSizes(data_format, mkl_params->in_strides,
    414                       mkl_params->in_sizes);
    415 
    416   mkl_params->out_sizes[0] = params.out_width;
    417   mkl_params->out_sizes[1] = params.out_height;
    418   mkl_params->out_sizes[2] = params.depth;
    419   mkl_params->out_sizes[3] = params.tensor_in_batch;
    420 
    421   GetStridesFromSizes(data_format, mkl_params->out_strides,
    422                       mkl_params->out_sizes);
    423 
    424   mkl_params->in_offset[0] = -params.pad_left;
    425   mkl_params->in_offset[1] = -params.pad_top;
    426   mkl_params->in_offset[2] = -params.pad_right;
    427   mkl_params->in_offset[3] = -params.pad_bottom;
    428 
    429   mkl_params->kernel_stride[0] = params.col_stride;
    430   mkl_params->kernel_stride[1] = params.row_stride;
    431 
    432   mkl_params->kernel_size[0] = params.window_cols;
    433   mkl_params->kernel_size[1] = params.window_rows;
    434 }
    435 }  // namespace tensorflow
    436 #endif  // INTEL_MKL
    437