/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL

#include "tensorflow/core/kernels/mkl_pooling_ops_common.h"

#include <limits>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/common_shape_fns.h"

namespace tensorflow {

#ifndef INTEL_MKL_ML_ONLY

using mkldnn::pooling_avg;
using mkldnn::pooling_avg_exclude_padding;
using mkldnn::pooling_avg_include_padding;
using mkldnn::pooling_max;
using mkldnn::prop_kind;
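
// Note on the structure below: each primitive is created once in Setup()
// against placeholder (DummyData) memory and then reused across calls to
// Execute(), which temporarily points the MKL-DNN memory objects at the
// caller's buffers via set_data_handle() and resets them afterwards. This
// keeps the relatively expensive primitive construction out of the per-step
// execution path.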

template <typename T>
void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
  DCHECK(fwdParams.alg_kind == pooling_max ||
         fwdParams.alg_kind == pooling_avg ||
         fwdParams.alg_kind == pooling_avg_include_padding ||
         fwdParams.alg_kind == pooling_avg_exclude_padding)
      << "Pooling algorithm kind is not supported";

  context_.alg_kind = fwdParams.alg_kind;
  context_.prop_kind = fwdParams.prop_kind;

  // Create memory descriptors.
  // FIXME: Pooling does not expose the src_primitive_desc, so the src format
  // is hard-coded here. A utility function chooses the format, which may
  // break on future CPU architectures.
  bool is_2d = (fwdParams.src_dims.size() == 4);
  if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value)
    context_.src_fmt = is_2d ? memory::format::nhwc : memory::format::ndhwc;
  else
    context_.src_fmt = get_desired_format(fwdParams.src_dims[1], is_2d);

  context_.src_md.reset(new memory::desc({fwdParams.src_dims}, MklDnnType<T>(),
                                         context_.src_fmt));
  context_.dst_md.reset(new memory::desc({fwdParams.dst_dims}, MklDnnType<T>(),
                                         memory::format::any));

  // Create a pooling descriptor.
  context_.fwd_desc.reset(new pooling_forward::desc(
      fwdParams.prop_kind, fwdParams.alg_kind, *context_.src_md,
      *context_.dst_md, fwdParams.strides, fwdParams.filter_dims,
      fwdParams.padding_left, fwdParams.padding_right, padding_kind::zero));
  context_.fwd_pd.reset(
      new pooling_forward::primitive_desc(*context_.fwd_desc, cpu_engine_));

  // Store the expected primitive format.
  context_.dst_fmt = static_cast<mkldnn::memory::format>(
      context_.fwd_pd.get()->dst_primitive_desc().desc().data.format);

  // Create MKL-DNN internal memory objects with dummy data.
  context_.src_mem.reset(new memory(
      {{{fwdParams.src_dims}, MklDnnType<T>(), context_.src_fmt}, cpu_engine_},
      DummyData));
  context_.dst_mem.reset(
      new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));

  // For max pooling, a workspace (ws) must be returned for the backward pass.
  if (fwdParams.alg_kind == pooling_max &&
      fwdParams.prop_kind == prop_kind::forward_training) {
    auto ws_pd = context_.fwd_pd.get()->workspace_primitive_desc().desc().data;
    // Store the workspace dims, format, and type to create the ws tensor.
    context_.ws_fmt = static_cast<mkldnn::memory::format>(ws_pd.format);
    context_.ws_dims.assign(ws_pd.dims, ws_pd.dims + ws_pd.ndims);
    context_.ws_dt = static_cast<mkldnn::memory::data_type>(ws_pd.data_type);
    context_.ws_size =
        context_.fwd_pd.get()->workspace_primitive_desc().get_size();
    context_.ws_mem.reset(new memory(
        context_.fwd_pd.get()->workspace_primitive_desc(), DummyData));
    context_.fwd.reset(new pooling_forward(*context_.fwd_pd, *context_.src_mem,
                                           *context_.dst_mem,
                                           *context_.ws_mem));
  } else {
    context_.fwd.reset(new pooling_forward(*context_.fwd_pd, *context_.src_mem,
                                           *context_.dst_mem));
  }

  context_.fwd_primitives.push_back(*context_.fwd);
}

template <typename T>
void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
                                        void* ws_data) {
  context_.src_mem->set_data_handle(
      static_cast<void*>(const_cast<T*>(src_data)));
  context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
  if (context_.alg_kind == pooling_max &&
      context_.prop_kind ==
          prop_kind::forward_training) {  // Max pooling must have a workspace.
    DCHECK(ws_data != nullptr);
    context_.ws_mem->set_data_handle(ws_data);
  }
  context_.fwd_stream->submit(context_.fwd_primitives);

  // Set the data handles back to the placeholder.
  context_.src_mem->set_data_handle(DummyData);
  context_.dst_mem->set_data_handle(DummyData);
  if (context_.alg_kind == pooling_max &&
      context_.prop_kind ==
          prop_kind::forward_training) {  // Max pooling must have a workspace.
    DCHECK(ws_data != nullptr);
    context_.ws_mem->set_data_handle(DummyData);
  }
}

template class MklPoolingFwdPrimitive<float>;
template class MklPoolingFwdPrimitive<quint8>;
template class MklPoolingFwdPrimitive<qint8>;
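
// Illustrative usage sketch (not part of this file): a pooling kernel is
// expected to build an MklPoolingParams and fetch a cached primitive through
// the factory declared in this file's header. The exact constructor and
// Get() signatures below are assumptions for illustration:
//
//   MklPoolingParams fwdParams(src_dims, dst_dims, filter_dims, strides,
//                              padding_left, padding_right, pooling_max,
//                              prop_kind::forward_training);
//   MklPoolingFwdPrimitive<float>* pooling_fwd =
//       MklPoolingFwdPrimitiveFactory<float>::Get(fwdParams);
//   pooling_fwd->Execute(src_data, dst_data, ws_data);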

template <typename T>
void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
  DCHECK(bwdParams.alg_kind == pooling_max ||
         bwdParams.alg_kind == pooling_avg ||
         bwdParams.alg_kind == pooling_avg_include_padding ||
         bwdParams.alg_kind == pooling_avg_exclude_padding)
      << "Pooling algorithm kind is not supported";
  context_.alg_kind = bwdParams.alg_kind;

  // Check whether it is 2D or 3D.
  bool is_2d = (bwdParams.dst_dims.size() == 4);

  // Create memory descriptors.
  context_.diff_src_md.reset(new memory::desc(
      {bwdParams.src_dims}, MklDnnType<T>(), memory::format::any));
  context_.diff_dst_md.reset(
      new memory::desc({bwdParams.dst_dims}, MklDnnType<T>(),
                       get_desired_format(bwdParams.dst_dims[1], is_2d)));
  context_.bwd_desc.reset(new pooling_backward::desc(
      bwdParams.alg_kind, *context_.diff_src_md, *context_.diff_dst_md,
      bwdParams.strides, bwdParams.filter_dims, bwdParams.padding_left,
      bwdParams.padding_right, padding_kind::zero));

  // Create a forward primitive descriptor, which is used as a hint when
  // creating the backward primitive descriptor.
  context_.fwd_desc.reset(new pooling_forward::desc(
      bwdParams.prop_kind, bwdParams.alg_kind, *context_.diff_src_md,
      *context_.diff_dst_md, bwdParams.strides, bwdParams.filter_dims,
      bwdParams.padding_left, bwdParams.padding_right, padding_kind::zero));
  context_.fwd_pd.reset(
      new pooling_forward::primitive_desc(*context_.fwd_desc, cpu_engine));
  context_.bwd_pd.reset(new pooling_backward::primitive_desc(
      *context_.bwd_desc, cpu_engine, *context_.fwd_pd));

  // Store the expected primitive format.
  context_.diff_src_fmt = static_cast<mkldnn::memory::format>(
      context_.bwd_pd.get()->diff_src_primitive_desc().desc().data.format);
  context_.diff_dst_fmt = get_desired_format(bwdParams.dst_dims[1], is_2d);

  // Create MKL-DNN internal memory objects with dummy data.
  context_.diff_src_mem.reset(
      new memory(context_.bwd_pd.get()->diff_src_primitive_desc(), DummyData));
  context_.diff_dst_mem.reset(new memory(
      {{{bwdParams.dst_dims}, MklDnnType<T>(), context_.diff_dst_fmt},
       cpu_engine},
      DummyData));

  // For max pooling, the backward pass needs the workspace from the forward
  // pass.
  if (bwdParams.alg_kind == pooling_max) {
    auto ws_pd = context_.fwd_pd.get()->workspace_primitive_desc().desc().data;
    context_.ws_dims.assign(ws_pd.dims, ws_pd.dims + ws_pd.ndims);
    context_.ws_fmt = get_desired_format(context_.ws_dims[1], is_2d);
    context_.ws_dt = static_cast<mkldnn::memory::data_type>(ws_pd.data_type);
    context_.ws_mem.reset(new memory(
        {{{context_.ws_dims}, context_.ws_dt, context_.ws_fmt}, cpu_engine},
        DummyData));
    context_.bwd.reset(
        new pooling_backward(*context_.bwd_pd, *context_.diff_dst_mem,
                             *context_.ws_mem, *context_.diff_src_mem));
  } else {
    context_.bwd.reset(new pooling_backward(
        *context_.bwd_pd, *context_.diff_dst_mem, *context_.diff_src_mem));
  }
  context_.bwd_primitives.push_back(*context_.bwd);
}

template <typename T>
void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
                                        T* diff_src_data,
                                        const void* ws_data) {
  context_.diff_dst_mem->set_data_handle(
      static_cast<void*>(const_cast<T*>(diff_dst_data)));
  context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
  if (context_.alg_kind == pooling_max) {
    DCHECK(ws_data != nullptr);
    context_.ws_mem->set_data_handle(const_cast<void*>(ws_data));
  }

  context_.bwd_stream->submit(context_.bwd_primitives);

  // Set the data handles back to the placeholder.
  context_.diff_dst_mem->set_data_handle(DummyData);
  context_.diff_src_mem->set_data_handle(DummyData);
  if (context_.alg_kind == pooling_max) {
    DCHECK(ws_data != nullptr);
    context_.ws_mem->set_data_handle(DummyData);
  }
}

template class MklPoolingBwdPrimitive<float>;
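
// Illustrative usage sketch for the backward pass (assumed factory API,
// mirroring the forward case): the workspace written by the forward
// max-pooling primitive is passed back in so gradients can be routed to the
// argmax locations recorded during the forward step:
//
//   MklPoolingBwdPrimitive<float>* pooling_bwd =
//       MklPoolingBwdPrimitiveFactory<float>::Get(bwdParams);
//   pooling_bwd->Execute(diff_dst_data, diff_src_data, ws_data);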

#endif  // !INTEL_MKL_ML_ONLY

// Initialization for the TensorFlow format.
void MklPoolParameters::Init(OpKernelContext* context,
                             const std::vector<int32>& ksize,
                             const std::vector<int32>& stride, Padding padding,
                             TensorFormat data_format,
                             const TensorShape& tensor_in_shape) {
  // For pooling, tensor_in should have 4 or 5 dimensions.
  OP_REQUIRES(context,
              tensor_in_shape.dims() == 4 || tensor_in_shape.dims() == 5,
              errors::InvalidArgument("tensor_in must be 4 or 5-dimensional"));

  depth = GetTensorDim(tensor_in_shape, data_format, 'C');
  if (tensor_in_shape.dims() == 4) {
    // Pool2D
    tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, 'W');
    tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, 'H');
  } else {
    // Pool3D
    tensor_in_planes = GetTensorDim(tensor_in_shape, data_format, '0');
    tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, '1');
    tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, '2');
  }
  tensor_in_batch = GetTensorDim(tensor_in_shape, data_format, 'N');

  Init(context, ksize, stride, padding, data_format);
}

#ifdef INTEL_MKL_ML_ONLY
// Initialization for the MKL format.
void MklPoolParameters::Init(OpKernelContext* context,
                             const std::vector<int32>& ksize,
                             const std::vector<int32>& stride, Padding padding,
                             TensorFormat data_format,
                             const MklShape* mklInputShape) {
  // Get the input sizes.
  depth = mklInputShape->GetSizes()[2];
  tensor_in_cols = mklInputShape->GetSizes()[0];
  tensor_in_rows = mklInputShape->GetSizes()[1];
  tensor_in_batch = mklInputShape->GetSizes()[3];

  Init(context, ksize, stride, padding, data_format);
}
#else
// Initialization for the MKL-DNN format.
void MklPoolParameters::Init(OpKernelContext* context,
                             const std::vector<int32>& ksize,
                             const std::vector<int32>& stride, Padding padding,
                             TensorFormat data_format,
                             const MklDnnShape* mklInputShape) {
  // Get the input sizes.
  if (ksize.size() == 4) {
    // Pool2D
    depth = mklInputShape->GetDimension('C');
    tensor_in_cols = mklInputShape->GetDimension('W');
    tensor_in_rows = mklInputShape->GetDimension('H');
    tensor_in_batch = mklInputShape->GetDimension('N');
  } else {
    // Pool3D
    depth = mklInputShape->GetDimension3D('C');
    tensor_in_cols = mklInputShape->GetDimension3D('W');
    tensor_in_rows = mklInputShape->GetDimension3D('H');
    tensor_in_planes = mklInputShape->GetDimension3D('D');
    tensor_in_batch = mklInputShape->GetDimension3D('N');
  }

  Init(context, ksize, stride, padding, data_format);
}
#endif  // INTEL_MKL_ML_ONLY
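
// Note on layouts: ksize and stride follow the data_format ordering. For
// example, with FORMAT_NHWC and ksize = {1, 3, 3, 1}, GetTensorDim(ksize,
// data_format, 'H') and GetTensorDim(ksize, data_format, 'W') both return 3,
// while GetTensorDim(ksize, data_format, 'C') returns 1 (spatial pooling).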
// Common initialization for the TensorFlow and MKL formats.
void MklPoolParameters::Init(OpKernelContext* context,
                             const std::vector<int32>& ksize,
                             const std::vector<int32>& stride, Padding padding,
                             TensorFormat data_format) {
  // Get the data format.
  this->data_format = data_format;

  bool is_pool2d = (ksize.size() == 4);
  if (is_pool2d) {
    // Pool2D
    // Get the window sizes.
    window_rows = GetTensorDim(ksize, data_format, 'H');
    window_cols = GetTensorDim(ksize, data_format, 'W');
    depth_window = GetTensorDim(ksize, data_format, 'C');

    // Get the strides.
    row_stride = GetTensorDim(stride, data_format, 'H');
    col_stride = GetTensorDim(stride, data_format, 'W');
    depth_stride = GetTensorDim(stride, data_format, 'C');

    // We only support 2D pooling across width/height or depthwise pooling,
    // not a combination.
    OP_REQUIRES(context,
                (depth_window == 1 || (window_rows == 1 && window_cols == 1)),
                errors::Unimplemented(
                    "MaxPooling supports exactly one of pooling across depth "
                    "or pooling across width/height."));
  } else {
    // Pool3D
    // Get the window sizes.
    window_planes = GetTensorDim(ksize, data_format, '0');
    window_rows = GetTensorDim(ksize, data_format, '1');
    window_cols = GetTensorDim(ksize, data_format, '2');
    depth_window = GetTensorDim(ksize, data_format, 'C');

    // Get the strides.
    planes_stride = GetTensorDim(stride, data_format, '0');
    row_stride = GetTensorDim(stride, data_format, '1');
    col_stride = GetTensorDim(stride, data_format, '2');
    depth_stride = GetTensorDim(stride, data_format, 'C');

    // We only support 3D pooling across planes/width/height or depthwise
    // pooling, not a combination.
    OP_REQUIRES(context,
                (depth_window == 1 ||
                 (window_rows == 1 && window_cols == 1 && window_planes == 1)),
                errors::Unimplemented(
                    "AvgPooling3D supports exactly one of pooling across depth "
                    "or pooling across depth/width/height."));
  }

  if (depth_window == 1) {  // Pooling in the D (Pool3D only), H, and W dims.
    if (!is_pool2d) {
      OP_REQUIRES_OK(
          context, GetWindowedOutputSizeVerbose(tensor_in_planes, window_planes,
                                                planes_stride, padding,
                                                &out_planes, &pad_P1, &pad_P2));
    }

    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                tensor_in_rows, window_rows, row_stride,
                                padding, &out_height, &pad_top, &pad_bottom));

    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                tensor_in_cols, window_cols, col_stride,
                                padding, &out_width, &pad_left, &pad_right));
#ifndef INTEL_MKL_ML_ONLY
    // TF can work with int64, but mkldnn only supports int32.
    // Fail if the depth, height, or width is greater than MAX_INT.
    // Depth is checked only in the 3D pooling case.
    if (!is_pool2d) {
      OP_REQUIRES(context,
                  FastBoundsCheck(out_planes, std::numeric_limits<int>::max()),
                  errors::InvalidArgument("output depth/planes is too large"));
    }

    OP_REQUIRES(context,
                FastBoundsCheck(out_height, std::numeric_limits<int>::max()),
                errors::InvalidArgument("output height is too large"));

    OP_REQUIRES(context,
                FastBoundsCheck(out_width, std::numeric_limits<int>::max()),
                errors::InvalidArgument("output width is too large"));
#endif
    out_depth = depth;  // The output has the same depth as the input.
  } else {  // Pooling in the depth dimension.
    // Our current version of depthwise max pooling does not support any
    // padding, and expects depth_window to equal depth_stride (no
    // overlapping windows).
    OP_REQUIRES(context, depth % depth_window == 0,
                errors::Unimplemented("Depthwise max pooling requires the"
                                      " depth window to evenly divide the"
                                      " input depth"));
    OP_REQUIRES(context, depth_stride == depth_window,
                errors::Unimplemented("Depthwise max pooling requires the"
                                      " depth window to equal the depth"
                                      " stride"));

    // The current version of depthwise max pooling is implemented on CPU only.
    OP_REQUIRES(context,
                (DeviceType(static_cast<Device*>(context->device())
                                ->attributes()
                                .device_type()) == DeviceType(DEVICE_CPU)),
                errors::Unimplemented("Depthwise max pooling is currently "
                                      "only implemented for CPU devices."));

    out_depth = depth / depth_window;
  }
}
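
// Worked example of the GetWindowedOutputSizeVerbose arithmetic used above
// (a property of the TensorFlow helper, not of MKL): with input size 112,
// window 3, stride 2, and SAME padding, out = ceil(112 / 2) = 56; the total
// padding needed is max((56 - 1) * 2 + 3 - 112, 0) = 1, split as
// pad_before = 0 (top/left) and pad_after = 1 (bottom/right). With VALID
// padding, out = ceil((112 - 3 + 1) / 2) = 55 and both pads are 0.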

// Transfers the right parameters for pooling to the op parameters.
// Updates context->status if there is an invalid input.
void ExtractMklOpParams(OpKernelContext* context, TensorFormat data_format,
                        const MklPoolParameters& params,
                        MklPoolingOpParams* mkl_params) {
  mkl_params->in_sizes[0] = params.tensor_in_cols;
  mkl_params->in_sizes[1] = params.tensor_in_rows;
  mkl_params->in_sizes[2] = params.depth;
  mkl_params->in_sizes[3] = params.tensor_in_batch;

  GetStridesFromSizes(data_format, mkl_params->in_strides,
                      mkl_params->in_sizes);

  mkl_params->out_sizes[0] = params.out_width;
  mkl_params->out_sizes[1] = params.out_height;
  mkl_params->out_sizes[2] = params.depth;
  mkl_params->out_sizes[3] = params.tensor_in_batch;

  GetStridesFromSizes(data_format, mkl_params->out_strides,
                      mkl_params->out_sizes);

  mkl_params->in_offset[0] = -params.pad_left;
  mkl_params->in_offset[1] = -params.pad_top;
  mkl_params->in_offset[2] = -params.pad_right;
  mkl_params->in_offset[3] = -params.pad_bottom;

  mkl_params->kernel_stride[0] = params.col_stride;
  mkl_params->kernel_stride[1] = params.row_stride;

  mkl_params->kernel_size[0] = params.window_cols;
  mkl_params->kernel_size[1] = params.window_rows;
}

}  // namespace tensorflow

#endif  // INTEL_MKL